Commit 1783b652 authored by Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/gemm_tile_loop

parents 4e755e73 a66d14ed
# Documentation files
docs/* @saadrahim @LisaDelaney
*.md @saadrahim @LisaDelaney
*.rst @saadrahim @LisaDelaney
# Header directory
library/include/* @saadrahim @LisaDelaney
...@@ -15,6 +15,12 @@ if (DTYPES) ...@@ -15,6 +15,12 @@ if (DTYPES)
if (DTYPES MATCHES "fp8") if (DTYPES MATCHES "fp8")
add_definitions(-DCK_ENABLE_FP8) add_definitions(-DCK_ENABLE_FP8)
set(CK_ENABLE_FP8 "ON") set(CK_ENABLE_FP8 "ON")
add_compile_options(-Wno-bit-int-extension)
endif()
if (DTYPES MATCHES "bf8")
add_definitions(-DCK_ENABLE_BF8)
set(CK_ENABLE_BF8 "ON")
add_compile_options(-Wno-bit-int-extension)
endif() endif()
if (DTYPES MATCHES "fp16") if (DTYPES MATCHES "fp16")
add_definitions(-DCK_ENABLE_FP16) add_definitions(-DCK_ENABLE_FP16)
...@@ -34,8 +40,9 @@ if (DTYPES) ...@@ -34,8 +40,9 @@ if (DTYPES)
endif() endif()
message("DTYPES macro set to ${DTYPES}") message("DTYPES macro set to ${DTYPES}")
else() else()
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
set(CK_ENABLE_ALL_DTYPES "ON") set(CK_ENABLE_ALL_DTYPES "ON")
add_compile_options(-Wno-bit-int-extension) # enable fp8 and bf8
endif() endif()
if(DL_KERNELS) if(DL_KERNELS)
...@@ -365,6 +372,10 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ...@@ -365,6 +372,10 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
#message("fp8 instance found!") #message("fp8 instance found!")
set(add_inst 1) set(add_inst 1)
endif() endif()
# Select this instance file when its CMake text contains a `DTYPES MATCHES "bf8" `
# guard and the configured DTYPES list itself mentions bf8.
# NOTE(review): the pattern below has a trailing space after the closing quote of
# "bf8", whereas the fp16 pattern that follows does not - confirm this is intended,
# since a guard written as `DTYPES MATCHES "bf8")` would not match.
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf8\" " AND DTYPES MATCHES "bf8")
#message("bf8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16") if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
#message("fp16 instance found!") #message("fp16 instance found!")
set(add_inst 1) set(add_inst 1)
......
...@@ -210,6 +210,9 @@ def cmake_build(Map conf=[:]){ ...@@ -210,6 +210,9 @@ def cmake_build(Map conf=[:]){
} else{ } else{
setup_args = ' -DBUILD_DEV=On' + setup_args setup_args = ' -DBUILD_DEV=On' + setup_args
} }
if (params.DL_KERNELS){
setup_args = setup_args + " -DDL_KERNELS=ON "
}
if(build_type_debug){ if(build_type_debug){
setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args
...@@ -367,8 +370,6 @@ def runCKProfiler(Map conf=[:]){ ...@@ -367,8 +370,6 @@ def runCKProfiler(Map conf=[:]){
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 24, unit: 'HOURS') timeout(time: 24, unit: 'HOURS')
{ {
//cmake_build(conf)
//instead of building, just unstash the ckProfiler and install it
sh """ sh """
rm -rf build rm -rf build
mkdir build mkdir build
...@@ -614,7 +615,7 @@ def process_results(Map conf=[:]){ ...@@ -614,7 +615,7 @@ def process_results(Map conf=[:]){
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1 CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1
0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT= 0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT=
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""
pipeline { pipeline {
agent none agent none
...@@ -649,6 +650,10 @@ pipeline { ...@@ -649,6 +650,10 @@ pipeline {
name: "RUN_FULL_QA", name: "RUN_FULL_QA",
defaultValue: false, defaultValue: false,
description: "Select whether to run small set of performance tests (default) or full QA") description: "Select whether to run small set of performance tests (default) or full QA")
booleanParam(
name: "DL_KERNELS",
defaultValue: false,
description: "Select whether to build DL kernels (default: OFF)")
} }
environment{ environment{
dbuser = "${dbuser}" dbuser = "${dbuser}"
...@@ -663,15 +668,12 @@ pipeline { ...@@ -663,15 +668,12 @@ pipeline {
} }
stages{ stages{
stage("Build Docker"){ stage("Build Docker"){
//when {
// beforeAgent true
// expression { params.BUILD_DOCKER.toBoolean() }
//}
parallel{ parallel{
stage('Docker /opt/rocm'){ stage('Docker /opt/rocm'){
agent{ label rocmnode("nogpu") } agent{ label rocmnode("nogpu") }
steps{ steps{
buildDocker('/opt/rocm') buildDocker('/opt/rocm')
cleanWs()
} }
} }
} }
...@@ -693,6 +695,7 @@ pipeline { ...@@ -693,6 +695,7 @@ pipeline {
} }
steps{ steps{
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
cleanWs()
} }
} }
} }
...@@ -715,6 +718,7 @@ pipeline { ...@@ -715,6 +718,7 @@ pipeline {
} }
steps{ steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
} }
} }
stage("Build CK and run Tests on MI100/MI200") stage("Build CK and run Tests on MI100/MI200")
...@@ -730,6 +734,7 @@ pipeline { ...@@ -730,6 +734,7 @@ pipeline {
} }
steps{ steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
} }
} }
stage("Build CK and run Tests on Navi21") stage("Build CK and run Tests on Navi21")
...@@ -742,10 +747,10 @@ pipeline { ...@@ -742,10 +747,10 @@ pipeline {
environment{ environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
} }
steps{ steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
} }
} }
stage("Build CK and run Tests on Navi32") stage("Build CK and run Tests on Navi32")
...@@ -756,12 +761,12 @@ pipeline { ...@@ -756,12 +761,12 @@ pipeline {
} }
agent{ label rocmnode("navi32") } agent{ label rocmnode("navi32") }
environment{ environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
} }
steps{ steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
} }
} }
} }
...@@ -784,6 +789,7 @@ pipeline { ...@@ -784,6 +789,7 @@ pipeline {
} }
steps{ steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
cleanWs()
} }
} }
stage("Run ckProfiler: gfx90a") stage("Run ckProfiler: gfx90a")
...@@ -799,6 +805,7 @@ pipeline { ...@@ -799,6 +805,7 @@ pipeline {
} }
steps{ steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
cleanWs()
} }
} }
} }
...@@ -811,6 +818,7 @@ pipeline { ...@@ -811,6 +818,7 @@ pipeline {
agent { label 'mici' } agent { label 'mici' }
steps{ steps{
process_results() process_results()
cleanWs()
} }
} }
} }
......
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations) add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations)
endif()
...@@ -69,5 +69,7 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) ...@@ -69,5 +69,7 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
endif() endif()
endif() endif()
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
endif()
...@@ -14,18 +14,22 @@ using ComputeDataType = float; ...@@ -14,18 +14,22 @@ using ComputeDataType = float;
struct YElementOp struct YElementOp
{ {
template <typename T> template <typename Y, typename X>
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value || static_assert(ck::is_same<X, float>::value || ck::is_same<X, double>::value ||
ck::is_same<T, ck::half_t>::value, ck::is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T a; static_assert(ck::is_same<Y, float>::value || ck::is_same<Y, double>::value ||
ck::is_same<Y, ck::half_t>::value,
"Data type is not supported by this operation!");
X a;
ck::tensor_operation::element_wise::Sigmoid{}(a, x); ck::tensor_operation::element_wise::Sigmoid{}(a, x);
y = x * a; y = ck::type_convert<Y>(x * a);
}; };
}; };
......
...@@ -43,6 +43,9 @@ ...@@ -43,6 +43,9 @@
#ifndef CK_ENABLE_FP8 #ifndef CK_ENABLE_FP8
#define CK_ENABLE_FP8 "ON" #define CK_ENABLE_FP8 "ON"
#endif #endif
#ifndef CK_ENABLE_BF8
#define CK_ENABLE_BF8 "ON"
#endif
#ifndef CK_ENABLE_FP16 #ifndef CK_ENABLE_FP16
#define CK_ENABLE_FP16 "ON" #define CK_ENABLE_FP16 "ON"
#endif #endif
...@@ -66,6 +69,10 @@ ...@@ -66,6 +69,10 @@
#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@ #cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@
#endif #endif
#ifndef CK_ENABLE_BF8
#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@
#endif
#ifndef CK_ENABLE_FP16 #ifndef CK_ENABLE_FP16
#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@ #cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@
#endif #endif
......
...@@ -168,7 +168,8 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout, ...@@ -168,7 +168,8 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& karg) static bool IsSupportedArgument(const Argument& karg)
{ {
if(ck::get_device_name() == "gfx1030") if(ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1100" ||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
{ {
return GridwiseGemm::CheckValidity(karg); return GridwiseGemm::CheckValidity(karg);
} }
......
...@@ -144,7 +144,8 @@ template <typename ALayout, ...@@ -144,7 +144,8 @@ template <typename ALayout,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeDataType = EDataType>
struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
...@@ -243,11 +244,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -243,11 +244,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1)); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = EDataType;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType,
BDataType, BDataType,
ComputeDataType, ComputeDataType,
AccDataType, AccDataType,
......
...@@ -27,6 +27,12 @@ struct PassThrough ...@@ -27,6 +27,12 @@ struct PassThrough
y = x; y = x;
} }
// Specialization for a double source and a float destination:
// writes type_convert<float>(x) into y.
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
{
y = type_convert<float>(x);
}
template <> template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const __host__ __device__ void operator()<float, float>(float& y, const float& x) const
{ {
...@@ -81,6 +87,12 @@ struct PassThrough ...@@ -81,6 +87,12 @@ struct PassThrough
y = type_convert<int8_t>(x); y = type_convert<int8_t>(x);
} }
// Specialization for a float source and an int8_t destination:
// writes type_convert<int8_t>(x) into y.
template <>
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <> template <>
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const __host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
...@@ -89,6 +101,7 @@ struct PassThrough ...@@ -89,6 +101,7 @@ struct PassThrough
} }
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{ {
...@@ -118,6 +131,7 @@ struct PassThrough ...@@ -118,6 +131,7 @@ struct PassThrough
{ {
y = type_convert<f8_t>(x); y = type_convert<f8_t>(x);
} }
#endif
}; };
struct UnaryConvert struct UnaryConvert
...@@ -146,6 +160,7 @@ struct ConvertBF16RTN ...@@ -146,6 +160,7 @@ struct ConvertBF16RTN
} }
}; };
#if defined CK_ENABLE_FP8
struct ConvertF8SR struct ConvertF8SR
{ {
// convert to fp8 using stochastic rounding (SR) // convert to fp8 using stochastic rounding (SR)
...@@ -162,6 +177,7 @@ struct ConvertF8SR ...@@ -162,6 +177,7 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x); y = f8_convert_sr<Y>(x);
} }
}; };
#endif
struct Scale struct Scale
{ {
...@@ -412,14 +428,19 @@ struct Swish ...@@ -412,14 +428,19 @@ struct Swish
{ {
Swish(float beta = 1.0f) : beta_(beta) {} Swish(float beta = 1.0f) : beta_(beta) {}
template <typename T> template <typename Y, typename X>
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<T, ck::half_t>::value, is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!");
static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
is_same<Y, ck::half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x)); float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
}; };
float beta_ = 1.0f; float beta_ = 1.0f;
......
...@@ -28,7 +28,8 @@ __global__ void ...@@ -28,7 +28,8 @@ __global__ void
#endif #endif
kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane( const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
......
...@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v; DstData v;
// apply element-wise operation // apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]); element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert dst_vector.template AsType<DstData>()(i) = v;
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
}); });
const bool is_dst_valid = const bool is_dst_valid =
...@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ...@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr index_t dst_offset = dst_desc.CalculateOffset( constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v; DstData v;
// apply element-wise operation // apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]); element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert // apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v); dst_buf(Number<dst_offset>{}) = v;
}); });
}); });
} }
......
...@@ -11,8 +11,14 @@ namespace ck { ...@@ -11,8 +11,14 @@ namespace ck {
enum struct DppInstr enum struct DppInstr
{ {
dpp8_f16_16x16x2 = 0, dpp8_f16_1x32x2 = 0,
dpp8_f16_2x16x2,
dpp8_f16_2x32x2,
dpp8_f16_4x16x2,
dpp8_f16_4x32x2,
dpp8_f16_8x16x2,
dpp8_f16_8x32x2, dpp8_f16_8x32x2,
dpp8_f16_16x16x2,
dpp8_f16_32x8x2 dpp8_f16_32x8x2
}; };
...@@ -101,6 +107,36 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2> ...@@ -101,6 +107,36 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2>
} }
}; };
/// Traits for DppInstr::dpp8_f16_8x16x2 (half_t, m_per_wave = 8, n_per_wave = 16,
/// k_per_dpp = 2).
template <>
struct dpp_type<DppInstr::dpp8_f16_8x16x2>
{
    // Wave-level extents.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;
    static constexpr index_t m_per_wave     = 8;
    static constexpr index_t n_per_wave     = 16;
    // Lanegroup-level extents.
    static constexpr index_t m_per_lanegroup = 4;
    static constexpr index_t n_per_lanegroup = 8;
    // Thread-level extents.
    static constexpr index_t m_per_thread = 4;
    static constexpr index_t n_per_thread = 1;
    static constexpr index_t k_per_dpp    = 2;
    static constexpr bool share_a         = true;
    using BaseType                        = half_t;

    /// Runs dpp8::DppLanegroupGemm, parameterized by the constants above, on (a, b)
    /// and reg_c. MPerDpp and NPerDpp are accepted but not referenced in the body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
template <> template <>
struct dpp_type<DppInstr::dpp8_f16_16x16x2> struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{ {
...@@ -131,6 +167,156 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2> ...@@ -131,6 +167,156 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2>
} }
}; };
/// Traits for DppInstr::dpp8_f16_4x32x2 (half_t, m_per_wave = 4, n_per_wave = 32,
/// k_per_dpp = 2).
template <>
struct dpp_type<DppInstr::dpp8_f16_4x32x2>
{
    // Wave-level extents.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;
    static constexpr index_t m_per_wave     = 4;
    static constexpr index_t n_per_wave     = 32;
    // Lanegroup-level extents.
    static constexpr index_t m_per_lanegroup = 4;
    static constexpr index_t n_per_lanegroup = 8;
    // Thread-level extents.
    static constexpr index_t m_per_thread = 4;
    static constexpr index_t n_per_thread = 1;
    static constexpr index_t k_per_dpp    = 2;
    static constexpr bool share_a         = true;
    using BaseType                        = half_t;

    /// Runs dpp8::DppLanegroupGemm, parameterized by the constants above, on (a, b)
    /// and reg_c. MPerDpp and NPerDpp are accepted but not referenced in the body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
/// Traits for DppInstr::dpp8_f16_4x16x2 (half_t, m_per_wave = 4, n_per_wave = 16,
/// k_per_dpp = 2).
template <>
struct dpp_type<DppInstr::dpp8_f16_4x16x2>
{
    // Wave-level extents.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;
    static constexpr index_t m_per_wave     = 4;
    static constexpr index_t n_per_wave     = 16;
    // Lanegroup-level extents.
    static constexpr index_t m_per_lanegroup = 2;
    static constexpr index_t n_per_lanegroup = 8;
    // Thread-level extents.
    static constexpr index_t m_per_thread = 2;
    static constexpr index_t n_per_thread = 1;
    static constexpr index_t k_per_dpp    = 2;
    static constexpr bool share_a         = true;
    using BaseType                        = half_t;

    /// Runs dpp8::DppLanegroupGemm, parameterized by the constants above, on (a, b)
    /// and reg_c. MPerDpp and NPerDpp are accepted but not referenced in the body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
/// Traits for DppInstr::dpp8_f16_1x32x2 (half_t, m_per_wave = 1, n_per_wave = 32,
/// k_per_dpp = 2).
template <>
struct dpp_type<DppInstr::dpp8_f16_1x32x2>
{
    // Wave-level extents.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;
    static constexpr index_t m_per_wave     = 1;
    static constexpr index_t n_per_wave     = 32;
    // Lanegroup-level extents.
    static constexpr index_t m_per_lanegroup = 1;
    static constexpr index_t n_per_lanegroup = 8;
    // Thread-level extents.
    static constexpr index_t m_per_thread = 1;
    static constexpr index_t n_per_thread = 1;
    static constexpr index_t k_per_dpp    = 2;
    static constexpr bool share_a         = true;
    using BaseType                        = half_t;

    /// Runs dpp8::DppLanegroupGemm, parameterized by the constants above, on (a, b)
    /// and reg_c. MPerDpp and NPerDpp are accepted but not referenced in the body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
/// Traits for DppInstr::dpp8_f16_2x32x2 (half_t, m_per_wave = 2, n_per_wave = 32,
/// k_per_dpp = 2).
template <>
struct dpp_type<DppInstr::dpp8_f16_2x32x2>
{
    // Wave-level extents.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;
    static constexpr index_t m_per_wave     = 2;
    static constexpr index_t n_per_wave     = 32;
    // Lanegroup-level extents.
    static constexpr index_t m_per_lanegroup = 2;
    static constexpr index_t n_per_lanegroup = 8;
    // Thread-level extents.
    static constexpr index_t m_per_thread = 2;
    static constexpr index_t n_per_thread = 1;
    static constexpr index_t k_per_dpp    = 2;
    static constexpr bool share_a         = true;
    using BaseType                        = half_t;

    /// Runs dpp8::DppLanegroupGemm, parameterized by the constants above, on (a, b)
    /// and reg_c. MPerDpp and NPerDpp are accepted but not referenced in the body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
/// Traits for DppInstr::dpp8_f16_2x16x2 (half_t, m_per_wave = 2, n_per_wave = 16,
/// k_per_dpp = 2).
template <>
struct dpp_type<DppInstr::dpp8_f16_2x16x2>
{
    // Wave-level extents.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;
    static constexpr index_t m_per_wave     = 2;
    static constexpr index_t n_per_wave     = 16;
    // Lanegroup-level extents.
    static constexpr index_t m_per_lanegroup = 1;
    static constexpr index_t n_per_lanegroup = 8;
    // Thread-level extents.
    static constexpr index_t m_per_thread = 1;
    static constexpr index_t n_per_thread = 1;
    static constexpr index_t k_per_dpp    = 2;
    static constexpr bool share_a         = true;
    using BaseType                        = half_t;

    /// Runs dpp8::DppLanegroupGemm, parameterized by the constants above, on (a, b)
    /// and reg_c. MPerDpp and NPerDpp are accepted but not referenced in the body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
template <typename BaseType, index_t MPerDpp, index_t NPerDpp> template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
struct DppSelector struct DppSelector
{ {
...@@ -143,6 +329,12 @@ struct DppSelector ...@@ -143,6 +329,12 @@ struct DppSelector
return DppInstr::dpp8_f16_8x32x2; return DppInstr::dpp8_f16_8x32x2;
} }
// half_t with MPerDpp = 8, NPerDpp = 16 selects DppInstr::dpp8_f16_8x16x2.
template <>
static constexpr auto GetDpp<half_t, 8, 16>()
{
return DppInstr::dpp8_f16_8x16x2;
}
template <> template <>
static constexpr auto GetDpp<half_t, 16, 16>() static constexpr auto GetDpp<half_t, 16, 16>()
{ {
...@@ -155,6 +347,36 @@ struct DppSelector ...@@ -155,6 +347,36 @@ struct DppSelector
return DppInstr::dpp8_f16_32x8x2; return DppInstr::dpp8_f16_32x8x2;
} }
// half_t with MPerDpp = 1, NPerDpp = 32 selects DppInstr::dpp8_f16_1x32x2.
template <>
static constexpr auto GetDpp<half_t, 1, 32>()
{
return DppInstr::dpp8_f16_1x32x2;
}
// half_t with MPerDpp = 2, NPerDpp = 32 selects DppInstr::dpp8_f16_2x32x2.
template <>
static constexpr auto GetDpp<half_t, 2, 32>()
{
return DppInstr::dpp8_f16_2x32x2;
}
// half_t with MPerDpp = 2, NPerDpp = 16 selects DppInstr::dpp8_f16_2x16x2.
template <>
static constexpr auto GetDpp<half_t, 2, 16>()
{
return DppInstr::dpp8_f16_2x16x2;
}
// half_t with MPerDpp = 4, NPerDpp = 16 selects DppInstr::dpp8_f16_4x16x2.
template <>
static constexpr auto GetDpp<half_t, 4, 16>()
{
return DppInstr::dpp8_f16_4x16x2;
}
// half_t with MPerDpp = 4, NPerDpp = 32 selects DppInstr::dpp8_f16_4x32x2.
template <>
static constexpr auto GetDpp<half_t, 4, 32>()
{
return DppInstr::dpp8_f16_4x32x2;
}
static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{}; static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};
__host__ __device__ constexpr DppSelector() __host__ __device__ constexpr DppSelector()
...@@ -191,7 +413,6 @@ struct DppSelector ...@@ -191,7 +413,6 @@ struct DppSelector
// in the future when the implementation is more generalized. // in the future when the implementation is more generalized.
static_assert(selected_dpp.share_a); static_assert(selected_dpp.share_a);
static_assert(selected_dpp.n_per_thread == 1); static_assert(selected_dpp.n_per_thread == 1);
static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size);
static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
static_assert(selected_dpp.n_per_lanegroup == static_assert(selected_dpp.n_per_lanegroup ==
selected_dpp.n_per_thread * selected_dpp.lanegroup_size); selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
...@@ -215,11 +436,6 @@ struct DppGemm ...@@ -215,11 +436,6 @@ struct DppGemm
__host__ __device__ constexpr DppGemm() __host__ __device__ constexpr DppGemm()
{ {
static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32,
"MPerDpp must be either 8, 16 or 32.");
static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32,
"NPerDpp must be either 8, 16 or 32.");
static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp."); static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
} }
......
...@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64> ...@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
} }
}; };
#if defined CK_ENABLE_FP8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{ {
...@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8> ...@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops> template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector struct MfmaSelector
...@@ -640,6 +642,7 @@ struct MfmaSelector ...@@ -640,6 +642,7 @@ struct MfmaSelector
} }
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() static constexpr auto GetMfma<f8_t, 32, 32>()
{ {
...@@ -651,6 +654,7 @@ struct MfmaSelector ...@@ -651,6 +654,7 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32f8f8; return MfmaInstr::mfma_f32_16x16x32f8f8;
} }
#endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{}; static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
...@@ -852,7 +856,11 @@ struct XdlopsGemm ...@@ -852,7 +856,11 @@ struct XdlopsGemm
{ {
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value || static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value || is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value, is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8
|| is_same<base_type, f8_t>::value
#endif
,
"base base_type must be double, float, half, bfloat16, and int8_t!"); "base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
...@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
} }
else else
{ {
#endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>( return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
} }
#endif
#else #else
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
} }
else else
{ {
#endif
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t(0);
#if defined CK_ENABLE_FP8
} }
#endif #endif
#endif
} }
// buffer_load requires: // buffer_load requires:
...@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = auto tmp =
...@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
} }
else else
{ {
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
} }
#endif
#else #else
if(dst_thread_element_valid) if(dst_thread_element_valid)
{ {
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>( auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
...@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
} }
else else
{ {
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
} }
#endif
} }
#endif #endif
} }
......
...@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> ...@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
} }
}; };
#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8; struct intrin_mfma_f32_32x32x16f8f8;
...@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> ...@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif #endif
} }
}; };
#endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -12,7 +12,12 @@ using half_t = _Float16; ...@@ -12,7 +12,12 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
#endif #endif
using f8_t = uint8_t; #if defined CK_ENABLE_FP8
using f8_t = _BitInt(8);
#endif
#if defined CK_ENABLE_BF8
using bf8_t = unsigned _BitInt(8);
#endif
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
...@@ -143,14 +148,24 @@ struct scalar_type<int4_t> ...@@ -143,14 +148,24 @@ struct scalar_type<int4_t>
}; };
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
struct scalar_type<f8_t> struct scalar_type<f8_t>
{ {
using type = f8_t; using type = f8_t;
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
#endif
#if defined CK_ENABLE_BF8
template <>
struct scalar_type<bf8_t>
{
using type = bf8_t;
static constexpr index_t vector_size = 1;
};
#endif
//
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
{ {
...@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type; ...@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type; using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8 // f8
#if defined CK_ENABLE_FP8
using f8x2_t = typename vector_type<f8_t, 2>::type; using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type; using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type; using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type; using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type; using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type; using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif
// bf8
#if defined CK_ENABLE_BF8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
...@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t> ...@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t>
}; };
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template <> template <>
struct NumericLimits<f8_t> struct NumericLimits<f8_t>
{ {
// negative zero nan mode with exp bias = 8
static constexpr uint8_t binary_min = 0x08; // 0b00001000 static constexpr uint8_t binary_min = 0x08; // 0b00001000
static constexpr uint8_t binary_max = 0x77; // 0b01110111 static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericLimits<bf8_t>
{
// negative zero nan mode with exp bias = 16
static constexpr uint8_t binary_min = 0x04; // 0b00000100
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); } __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); } __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); } __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); } __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
}; };
#endif
// Per-type bit-layout traits used by the fp8/bf8 conversion helpers.
// The primary template has no members; each supported type supplies
// its own specialization.
template <class T>
struct NumericUtils
{
};

// Layout description of a 32-bit float: 1 sign bit, 8 exponent bits,
// 23 mantissa bits.
template <>
struct NumericUtils<float>
{
    // unsigned integer type wide enough to hold the raw bit pattern
    using bitwise_type = uint32_t;

    // field widths, in bits
    static constexpr int mant = 23;
    static constexpr int exp  = 8;

    // masks applied to the raw bit pattern
    static constexpr uint32_t head_mask = 0xFF800000; // sign + exponent bits
    static constexpr uint32_t nan_mask  = 0x7F800000; // all exponent bits set
    static constexpr uint32_t mant_mask = 0x7FFFFF;   // mantissa bits
    static constexpr uint32_t exp_mask  = 0xFF;       // exponent field once shifted down by mant

    // raw bit patterns of the special values
    static constexpr uint32_t Neg0   = 0x80000000;
    static constexpr uint32_t NaN    = 0x7F800001;
    static constexpr uint32_t NegInf = 0xFF800000;
    static constexpr uint32_t Inf    = 0x7F800000;
};
// Layout description of half_t: 1 sign bit, 5 exponent bits,
// 10 mantissa bits.
template <>
struct NumericUtils<half_t>
{
    // unsigned integer type wide enough to hold the raw bit pattern
    using bitwise_type = uint16_t;

    // field widths, in bits
    static constexpr int mant = 10;
    static constexpr int exp  = 5;

    // masks applied to the raw bit pattern
    static constexpr uint16_t head_mask = 0xFC00; // sign + exponent bits
    static constexpr uint16_t nan_mask  = 0x7C00; // all exponent bits set
    static constexpr uint16_t mant_mask = 0x3FF;  // mantissa bits
    static constexpr uint16_t exp_mask  = 0x1F;   // exponent field once shifted down by mant

    // raw bit patterns of the special values
    // NOTE(review): these are declared uint32_t although bitwise_type is
    // uint16_t; every value fits in 16 bits.
    static constexpr uint32_t Neg0   = 0x8000;
    static constexpr uint32_t NaN    = 0x7C01;
    static constexpr uint32_t NegInf = 0xFC00;
    static constexpr uint32_t Inf    = 0x7C00;
};
#if defined(CK_ENABLE_FP8)
// Field widths of f8_t, only available when fp8 support is compiled in.
template <>
struct NumericUtils<f8_t>
{
    static constexpr int mant = 3; // mantissa bits
    static constexpr int exp  = 4; // exponent bits
};
#endif // CK_ENABLE_FP8
#if defined(CK_ENABLE_BF8)
// Field widths of bf8_t, only available when bf8 support is compiled in.
template <>
struct NumericUtils<bf8_t>
{
    static constexpr int mant = 2; // mantissa bits
    static constexpr int exp  = 5; // exponent bits
};
#endif // CK_ENABLE_BF8
} // namespace ck } // namespace ck
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace ck { namespace ck {
// fp8 rounding modes // fp8 rounding modes
...@@ -22,53 +23,38 @@ namespace ck::utils { ...@@ -22,53 +23,38 @@ namespace ck::utils {
namespace { namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{ {
// check data type // fp8/bf8 exponent/mantissa layout
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr int out_exp = NumericUtils<Y>::exp;
constexpr bool is_float = std::is_same<T, float>::value; constexpr int out_mant = NumericUtils<Y>::mant;
// fp8 exponent/mantissa layout // original type exponent/mantissa layout
constexpr int f8_exp = 4; constexpr int in_exp = NumericUtils<X>::exp;
constexpr int f8_mant = 3; constexpr int in_mant = NumericUtils<X>::mant;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
int exponent; int exponent;
uint32_t head, mantissa, sign; uint32_t head, mantissa, sign;
// nan code is same for float and half // nan code is same for float and half
constexpr uint8_t nan_code = 0x80; constexpr Y nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;
// convert to bitwise // convert to bitwise
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type using T_bitwise = typename NumericUtils<X>::bitwise_type;
T_bitwise;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x)); T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype // unpack the input, depends on datatype
if constexpr(is_float) head = x_bitwise & NumericUtils<X>::head_mask;
{ mantissa = x_bitwise & NumericUtils<X>::mant_mask;
head = x_bitwise & 0xFF800000; exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
mantissa = x_bitwise & 0x7FFFFF; sign = head >> (in_exp + in_mant);
exponent = (head >> type_mant) & 0xFF;
sign = head >> (type_exp + type_mant); uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
} uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
else if constexpr(is_half) constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
{
head = x_bitwise & 0xFC00;
mantissa = x_bitwise & 0x3FF;
exponent = (head >> type_mant) & 0x1F;
sign = head >> (type_exp + type_mant);
}
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff = constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
...@@ -81,22 +67,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -81,22 +67,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
return signed_inf + (mantissa != 0 ? 1 : 0); return signed_inf + (mantissa != 0 ? 1 : 0);
} }
// if input is half and output is bf8
if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
exponent == 0)
{
exponent += 1;
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent -= 1;
}
mantissa &= ~(1 << in_mant);
}
// check if x is 0.0 // check if x is 0.0
if(x_bitwise == 0) if(x_bitwise == 0)
return 0; return 0;
exponent -= exp_low_cutoff - 1; exponent -= exp_low_cutoff - 1;
if(exponent <= 0) if(exponent <= 0)
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
mantissa += 1 << type_mant; mantissa += 1 << in_mant;
// apply random number if needed // apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask; mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << type_mant)) if(mantissa >= (2 << in_mant))
{ {
mantissa >>= 1; mantissa >>= 1;
exponent++; exponent++;
} }
mantissa >>= (type_mant - f8_mant); mantissa >>= (in_mant - out_mant);
// check negative exponent // check negative exponent
if(exponent <= 0) if(exponent <= 0)
...@@ -116,7 +115,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -116,7 +115,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{ {
if(clip) if(clip)
{ {
mantissa = (1 << f8_mant) - 1; mantissa = (1 << out_mant) - 1;
exponent = max_exp; exponent = max_exp;
} }
else else
...@@ -127,124 +126,120 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -127,124 +126,120 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
// check if x is 0.0 or -0.0 // check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0) if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
mantissa &= (1 << f8_mant) - 1; mantissa &= (1 << out_mant) - 1;
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
} }
template <typename T, bool negative_zero_nan> template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x) __host__ __device__ Y run_cast_from_f8(X x)
{ {
// check data type // fp8/bf8 exponent/mantissa layout
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr int in_exp = NumericUtils<X>::exp;
constexpr bool is_float = std::is_same<T, float>::value; constexpr int in_mant = NumericUtils<X>::mant;
// fp8 exponent/mantissa layout
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
// resulting type exponent/mantissa layout // resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8; constexpr int out_exp = NumericUtils<Y>::exp;
constexpr int type_mant = is_half ? 10 : 23; constexpr int out_mant = NumericUtils<Y>::mant;
// prepare the codes // prepare the codes
constexpr uint8_t nan_code = 0x80; constexpr X nan_code = 0x80;
T fInf, fNegInf, fNaN, fNeg0; Y Inf, NegInf, NaN, Neg0;
if constexpr(is_half) using T_bitwise = typename NumericUtils<Y>::bitwise_type;
{
constexpr uint16_t ihInf = 0x7C00; constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
constexpr uint16_t ihNegInf = 0xFC00; constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
constexpr uint16_t ihNaN = 0x7C01; constexpr T_bitwise NaN_bitwise = NumericUtils<Y>::NaN;
constexpr uint16_t ihNeg0 = 0x8000; constexpr T_bitwise Neg0_bitwise = NumericUtils<Y>::Neg0;
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf)); Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN)); NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0)); NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
} Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
else if constexpr(is_float)
{ // check if x is 0.0
constexpr uint32_t ifInf = 0x7F800000; if(x == 0)
constexpr uint32_t ifNegInf = 0xFF800000; return static_cast<Y>(0);
constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
}
// unpack the input // unpack the input
uint32_t sign = x >> (f8_exp + f8_mant); uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << f8_mant) - 1); uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> f8_mant; int exponent = (x & 0x7F) >> in_mant;
constexpr int exp_low_cutoff = constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval; T_bitwise retval;
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
if(x == nan_code) if(x == nan_code)
return fNaN; return NaN;
} }
else else
{ {
if(x == nan_code) if(x == nan_code)
return fNeg0; return Neg0;
if(exponent == ((1 << f8_exp) - 1)) if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
}
if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
{
retval = x;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
} }
// subnormal input // subnormal input
if(exponent == 0) if(exponent == 0)
{ {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); exponent++;
mantissa <<= sh; while(mantissa < (1 << in_mant))
mantissa &= ((1 << f8_mant) - 1); {
exponent += 1 - sh; mantissa <<= 1;
exponent--;
}
mantissa &= ((1 << in_mant) - 1);
} }
exponent += exp_low_cutoff - 1; exponent += exp_low_cutoff - 1;
mantissa <<= type_mant - f8_mant; mantissa <<= out_mant - in_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true) // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0) if(exponent <= 0)
{ {
mantissa |= 1 << type_mant; mantissa |= 1 << out_mant;
mantissa >>= 1 - exponent; mantissa >>= 1 - exponent;
exponent = 0; exponent = 0;
} }
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return *(reinterpret_cast<const T*>(&retval)); return *(reinterpret_cast<const Y*>(&retval));
} }
} // namespace } // namespace
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) __host__ __device__ Y cast_to_f8(X x, uint32_t rng)
{ {
// check datatype // check datatypes
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted to f8."); static_assert(is_half || is_float, "Only half and float can be casted.");
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng); return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
} }
template <typename T, bool negative_zero_nan> template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x) __host__ __device__ Y cast_from_f8(X x)
{ {
// check datatype // check datatype
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported."); static_assert(is_half || is_float, "only half and float are supported.");
// check if x is 0.0 return run_cast_from_f8<X, Y, negative_zero_nan>(x);
if(x == 0)
return static_cast<T>(0);
return run_cast_from_f8<T, negative_zero_nan>(x);
} }
} // namespace ck::utils } // namespace ck::utils
#endif
...@@ -80,6 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -80,6 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32); return type_convert<bhalf_t>(x_fp32);
} }
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 // convert fp32 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
...@@ -88,8 +89,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) ...@@ -88,8 +89,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( return utils::
x, rng); cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
} }
// convert fp8 to fp32 // convert fp8 to fp32
...@@ -97,7 +99,7 @@ template <> ...@@ -97,7 +99,7 @@ template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{ {
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<float, negative_zero_nan>(x); return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
} }
// convert fp16 to fp8 // convert fp16 to fp8
...@@ -108,8 +110,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x) ...@@ -108,8 +110,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( return utils::
x, rng); cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
} }
// convert fp8 to fp16 // convert fp8 to fp16
...@@ -117,8 +120,53 @@ template <> ...@@ -117,8 +120,53 @@ template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{ {
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<half_t, negative_zero_nan>(x); return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
} }
#endif
#if defined CK_ENABLE_BF8
// fp32 -> bf8 using the standard (non-stochastic) rounding mode.
// Forwards to utils::cast_to_f8 with negative_zero_nan and clip both enabled.
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
{
    constexpr f8_rounding_mode rounding = f8_rounding_mode::standard;
    constexpr bool nz_nan               = true;
    constexpr bool clip_enabled         = true;
    constexpr bool stochastic           = (rounding == f8_rounding_mode::stochastic);
    // a constant zero is handed over as the random number for this mode
    constexpr uint32_t random_bits = 0;

    return utils::cast_to_f8<float, bf8_t, nz_nan, clip_enabled, stochastic>(x, random_bits);
}
// bf8 -> fp32; forwards to utils::cast_from_f8 with negative_zero_nan enabled.
template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
{
    return utils::cast_from_f8<bf8_t, float, /*negative_zero_nan=*/true>(x);
}
// fp16 -> bf8 using the standard (non-stochastic) rounding mode.
// Forwards to utils::cast_to_f8 with negative_zero_nan and clip both enabled.
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{
    constexpr f8_rounding_mode rounding = f8_rounding_mode::standard;
    constexpr bool nz_nan               = true;
    constexpr bool clip_enabled         = true;
    constexpr bool stochastic           = (rounding == f8_rounding_mode::stochastic);
    // a constant zero is handed over as the random number for this mode
    constexpr uint32_t random_bits = 0;

    return utils::cast_to_f8<half_t, bf8_t, nz_nan, clip_enabled, stochastic>(x, random_bits);
}
// bf8 -> fp16; forwards to utils::cast_from_f8 with negative_zero_nan enabled.
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
{
    return utils::cast_from_f8<bf8_t, half_t, /*negative_zero_nan=*/true>(x);
}
#endif
// Declare a template function for bf16 conversion using RTN // Declare a template function for bf16 conversion using RTN
template <typename Y, typename X> template <typename Y, typename X>
...@@ -181,6 +229,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h ...@@ -181,6 +229,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x); __host__ __device__ constexpr Y f8_convert_sr(X x);
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 with stochastic rounding // convert fp32 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
...@@ -191,8 +240,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) ...@@ -191,8 +240,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr int seed = 42; constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation // as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( return utils::
x, rng); cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
} }
// convert fp16 to fp8 with stochastic rounding // convert fp16 to fp8 with stochastic rounding
...@@ -205,8 +255,42 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -205,8 +255,42 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr int seed = 42; constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation // as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( return utils::
x, rng); cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
#endif
#if defined CK_ENABLE_BF8
// fp32 -> bf8 with stochastic rounding.
// Forwards to utils::cast_to_f8 with negative_zero_nan and clip both enabled.
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
    constexpr f8_rounding_mode rounding = f8_rounding_mode::stochastic;
    constexpr bool nz_nan               = true;
    constexpr bool clip_enabled         = true;
    constexpr bool stochastic           = (rounding == f8_rounding_mode::stochastic);
    constexpr int seed                  = 42;

    // no thread id is used here: the generator is fed the address of the
    // argument together with its value
    uint32_t random_bits = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);

    return utils::cast_to_f8<float, bf8_t, nz_nan, clip_enabled, stochastic>(x, random_bits);
}
// fp16 -> bf8 with stochastic rounding.
// Forwards to utils::cast_to_f8 with negative_zero_nan and clip both enabled.
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{
    constexpr f8_rounding_mode rounding = f8_rounding_mode::stochastic;
    constexpr bool nz_nan               = true;
    constexpr bool clip_enabled         = true;
    constexpr bool stochastic           = (rounding == f8_rounding_mode::stochastic);
    constexpr int seed                  = 42;

    // no thread id is used here: the generator is fed the address of the
    // argument together with its value
    uint32_t random_bits = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);

    return utils::cast_to_f8<half_t, bf8_t, nz_nan, clip_enabled, stochastic>(x, random_bits);
}
#endif
} // namespace ck } // namespace ck
...@@ -20,7 +20,8 @@ template <typename ADataType, ...@@ -20,7 +20,8 @@ template <typename ADataType,
typename AccDataType, typename AccDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation,
typename ComputType = ADataType>
struct ReferenceGemm : public device::BaseOperator struct ReferenceGemm : public device::BaseOperator
{ {
// Argument // Argument
...@@ -64,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -64,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
ADataType v_a; ComputType v_a;
BDataType v_b; ComputType v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation // use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation, if constexpr(is_same_v<AElementwiseOperation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment