Commit 1783b652 authored by Adam Osewski

Merge remote-tracking branch 'origin/develop' into aosewski/gemm_tile_loop

parents 4e755e73 a66d14ed
# Documentation files
docs/* @saadrahim @LisaDelaney
*.md @saadrahim @LisaDelaney
*.rst @saadrahim @LisaDelaney
# Header directory
library/include/* @saadrahim @LisaDelaney
......@@ -15,6 +15,12 @@ if (DTYPES)
if (DTYPES MATCHES "fp8")
add_definitions(-DCK_ENABLE_FP8)
set(CK_ENABLE_FP8 "ON")
add_compile_options(-Wno-bit-int-extension)
endif()
if (DTYPES MATCHES "bf8")
add_definitions(-DCK_ENABLE_BF8)
set(CK_ENABLE_BF8 "ON")
add_compile_options(-Wno-bit-int-extension)
endif()
if (DTYPES MATCHES "fp16")
add_definitions(-DCK_ENABLE_FP16)
......@@ -34,8 +40,9 @@ if (DTYPES)
endif()
message("DTYPES macro set to ${DTYPES}")
else()
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
set(CK_ENABLE_ALL_DTYPES "ON")
add_compile_options(-Wno-bit-int-extension) # enable fp8 and bf8
endif()
if(DL_KERNELS)
......@@ -365,6 +372,10 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
#message("fp8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf8\" " AND DTYPES MATCHES "bf8")
#message("bf8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
#message("fp16 instance found!")
set(add_inst 1)
......
......@@ -210,6 +210,9 @@ def cmake_build(Map conf=[:]){
} else{
setup_args = ' -DBUILD_DEV=On' + setup_args
}
if (params.DL_KERNELS){
setup_args = setup_args + " -DDL_KERNELS=ON "
}
if(build_type_debug){
setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args
......@@ -367,8 +370,6 @@ def runCKProfiler(Map conf=[:]){
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 24, unit: 'HOURS')
{
//cmake_build(conf)
//instead of building, just unstash the ckProfiler and install it
sh """
rm -rf build
mkdir build
......@@ -614,7 +615,7 @@ def process_results(Map conf=[:]){
//launch develop branch daily at 23:00 UTC in FULL_QA mode, at 21:00 UTC with the default ROCm 5.6 compiler, and at 19:00 UTC with the latest staging compiler
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1
0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT=
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""
pipeline {
agent none
......@@ -649,6 +650,10 @@ pipeline {
name: "RUN_FULL_QA",
defaultValue: false,
description: "Select whether to run small set of performance tests (default) or full QA")
booleanParam(
name: "DL_KERNELS",
defaultValue: false,
description: "Select whether to build DL kernels (default: OFF)")
}
environment{
dbuser = "${dbuser}"
......@@ -663,15 +668,12 @@ pipeline {
}
stages{
stage("Build Docker"){
//when {
// beforeAgent true
// expression { params.BUILD_DOCKER.toBoolean() }
//}
parallel{
stage('Docker /opt/rocm'){
agent{ label rocmnode("nogpu") }
steps{
buildDocker('/opt/rocm')
cleanWs()
}
}
}
......@@ -693,6 +695,7 @@ pipeline {
}
steps{
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
cleanWs()
}
}
}
......@@ -715,6 +718,7 @@ pipeline {
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
}
}
stage("Build CK and run Tests on MI100/MI200")
......@@ -730,6 +734,7 @@ pipeline {
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
}
}
stage("Build CK and run Tests on Navi21")
......@@ -742,10 +747,10 @@ pipeline {
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
}
}
stage("Build CK and run Tests on Navi32")
......@@ -756,12 +761,12 @@ pipeline {
}
agent{ label rocmnode("navi32") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
}
}
}
......@@ -784,6 +789,7 @@ pipeline {
}
steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
cleanWs()
}
}
stage("Run ckProfiler: gfx90a")
......@@ -799,6 +805,7 @@ pipeline {
}
steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
cleanWs()
}
}
}
......@@ -811,6 +818,7 @@ pipeline {
agent { label 'mici' }
steps{
process_results()
cleanWs()
}
}
}
......
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations)
endif()
......@@ -69,5 +69,7 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
endif()
endif()
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
endif()
......@@ -14,18 +14,22 @@ using ComputeDataType = float;
struct YElementOp
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
ck::is_same<T, ck::half_t>::value,
static_assert(ck::is_same<X, float>::value || ck::is_same<X, double>::value ||
ck::is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!");
T a;
static_assert(ck::is_same<Y, float>::value || ck::is_same<Y, double>::value ||
ck::is_same<Y, ck::half_t>::value,
"Data type is not supported by this operation!");
X a;
ck::tensor_operation::element_wise::Sigmoid{}(a, x);
y = x * a;
y = ck::type_convert<Y>(x * a);
};
};
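For reference, the rewritten operator computes the SiLU activation y = x * sigmoid(x) and now converts explicitly to the output type. A minimal host-side sanity check, assuming the headers defining YElementOp and ck::tensor_operation::element_wise::Sigmoid are on the include path:

#include <cassert>
#include <cmath>
// YElementOp and ck::tensor_operation::element_wise::Sigmoid as defined above

int main()
{
    float y = 0.f;
    YElementOp{}(y, 2.0f);
    // SiLU(2) = 2 / (1 + exp(-2)) ~= 1.7616
    assert(std::fabs(y - 2.0f / (1.0f + std::exp(-2.0f))) < 1e-5f);
    return 0;
}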
......
......@@ -43,6 +43,9 @@
#ifndef CK_ENABLE_FP8
#define CK_ENABLE_FP8 "ON"
#endif
#ifndef CK_ENABLE_BF8
#define CK_ENABLE_BF8 "ON"
#endif
#ifndef CK_ENABLE_FP16
#define CK_ENABLE_FP16 "ON"
#endif
......@@ -66,6 +69,10 @@
#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@
#endif
#ifndef CK_ENABLE_BF8
#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@
#endif
#ifndef CK_ENABLE_FP16
#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@
#endif
......
......@@ -168,7 +168,8 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& karg)
{
if(ck::get_device_name() == "gfx1030")
if(ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1100" ||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
{
return GridwiseGemm::CheckValidity(karg);
}
......
......@@ -144,7 +144,8 @@ template <typename ALayout,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeDataType = EDataType>
struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
BLayout,
DsLayout,
......@@ -243,11 +244,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = EDataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
ADataType,
BDataType,
ComputeDataType,
AccDataType,
......
......@@ -27,6 +27,12 @@ struct PassThrough
y = x;
}
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
......@@ -81,6 +87,12 @@ struct PassThrough
y = type_convert<int8_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
......@@ -89,6 +101,7 @@ struct PassThrough
}
#endif
#if defined CK_ENABLE_FP8
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{
......@@ -118,6 +131,7 @@ struct PassThrough
{
y = type_convert<f8_t>(x);
}
#endif
};
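The new mixed-type specializations make PassThrough usable as a narrowing copy rather than a strict identity. A host-side sketch (the namespace follows the Sigmoid usage elsewhere in this diff; the include is an assumption):

#include <cassert>
#include <cstdint>
// assumes the header defining ck::tensor_operation::element_wise::PassThrough

int main()
{
    using ck::tensor_operation::element_wise::PassThrough;

    float y = 0.f;
    PassThrough{}(y, 2.5);  // double -> float, via type_convert<float>
    assert(y == 2.5f);

    std::int8_t q = 0;
    PassThrough{}(q, 7.0f); // float -> int8_t, via type_convert<int8_t>
    return 0;
}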
struct UnaryConvert
......@@ -146,6 +160,7 @@ struct ConvertBF16RTN
}
};
#if defined CK_ENABLE_FP8
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
......@@ -162,6 +177,7 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x);
}
};
#endif
struct Scale
{
......@@ -412,14 +428,19 @@ struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!");
static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
is_same<Y, ck::half_t>::value,
"Data type is not supported by this operation!");
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
};
float beta_ = 1.0f;
......
......@@ -28,7 +28,8 @@ __global__ void
#endif
kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
......
......@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v;
DstData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
dst_vector.template AsType<DstData>()(i) = v;
});
const bool is_dst_valid =
......@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v;
DstData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
dst_buf(Number<dst_offset>{}) = v;
});
});
}
......
......@@ -11,8 +11,14 @@ namespace ck {
enum struct DppInstr
{
dpp8_f16_16x16x2 = 0,
dpp8_f16_1x32x2 = 0,
dpp8_f16_2x16x2,
dpp8_f16_2x32x2,
dpp8_f16_4x16x2,
dpp8_f16_4x32x2,
dpp8_f16_8x16x2,
dpp8_f16_8x32x2,
dpp8_f16_16x16x2,
dpp8_f16_32x8x2
};
......@@ -101,6 +107,36 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2>
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_8x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 8;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 4;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 4;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{
......@@ -131,6 +167,156 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2>
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_4x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 4;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 4;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 4;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_4x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 4;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 2;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 2;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_1x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 1;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 1;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 1;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_2x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 2;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 2;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 2;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_2x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 2;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 1;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 1;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
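For reference, the six new DPP8 f16 variants above all use wave_size = 32, lanegroup_size = 8, n_per_thread = 1 and k_per_dpp = 2, and differ only in the tile shape:

instr             m_per_wave  n_per_wave  m_per_lanegroup  n_per_lanegroup  m_per_thread
dpp8_f16_1x32x2   1           32          1                8                1
dpp8_f16_2x16x2   2           16          1                8                1
dpp8_f16_2x32x2   2           32          2                8                2
dpp8_f16_4x16x2   4           16          2                8                2
dpp8_f16_4x32x2   4           32          4                8                4
dpp8_f16_8x16x2   8           16          4                8                4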
template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
struct DppSelector
{
......@@ -143,6 +329,12 @@ struct DppSelector
return DppInstr::dpp8_f16_8x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 8, 16>()
{
return DppInstr::dpp8_f16_8x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 16, 16>()
{
......@@ -155,6 +347,36 @@ struct DppSelector
return DppInstr::dpp8_f16_32x8x2;
}
template <>
static constexpr auto GetDpp<half_t, 1, 32>()
{
return DppInstr::dpp8_f16_1x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 32>()
{
return DppInstr::dpp8_f16_2x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 16>()
{
return DppInstr::dpp8_f16_2x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 16>()
{
return DppInstr::dpp8_f16_4x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 32>()
{
return DppInstr::dpp8_f16_4x32x2;
}
static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};
__host__ __device__ constexpr DppSelector()
......@@ -191,7 +413,6 @@ struct DppSelector
// in the future when the implementation is more generalized.
static_assert(selected_dpp.share_a);
static_assert(selected_dpp.n_per_thread == 1);
static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size);
static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
static_assert(selected_dpp.n_per_lanegroup ==
selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
......@@ -215,11 +436,6 @@ struct DppGemm
__host__ __device__ constexpr DppGemm()
{
static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32,
"MPerDpp must be either 8, 16 or 32.");
static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32,
"NPerDpp must be either 8, 16 or 32.");
static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
}
......
......@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
#if defined CK_ENABLE_FP8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{
......@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector
......@@ -640,6 +642,7 @@ struct MfmaSelector
}
#endif
#if defined CK_ENABLE_FP8
template <>
static constexpr auto GetMfma<f8_t, 32, 32>()
{
......@@ -651,6 +654,7 @@ struct MfmaSelector
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}
#endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
......@@ -852,7 +856,11 @@ struct XdlopsGemm
{
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8
|| is_same<base_type, f8_t>::value
#endif
,
"base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
......@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
......@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
}
else
{
#endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
}
#endif
#else
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
......@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
}
else
{
#endif
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0);
#if defined CK_ENABLE_FP8
}
#endif
#endif
}
// buffer_load requires:
......@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp =
......@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
else
{
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
}
#endif
#else
if(dst_thread_element_valid)
{
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
......@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
else
{
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
}
#endif
}
#endif
}
......
......@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
......@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
}
};
#endif
} // namespace ck
#endif
......@@ -12,7 +12,12 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4);
#endif
using f8_t = uint8_t;
#if defined CK_ENABLE_FP8
using f8_t = _BitInt(8);
#endif
#if defined CK_ENABLE_BF8
using bf8_t = unsigned _BitInt(8);
#endif
// vector_type
template <typename T, index_t N>
......@@ -143,14 +148,24 @@ struct scalar_type<int4_t>
};
#endif
#if defined CK_ENABLE_FP8
template <>
struct scalar_type<f8_t>
{
using type = f8_t;
static constexpr index_t vector_size = 1;
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct scalar_type<bf8_t>
{
using type = bf8_t;
static constexpr index_t vector_size = 1;
};
#endif
//
template <typename T>
struct vector_type<T, 1>
{
......@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8
#if defined CK_ENABLE_FP8
using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif
// bf8
#if defined CK_ENABLE_BF8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template <typename T>
struct NumericLimits
......@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template <>
struct NumericLimits<f8_t>
{
// negative zero nan mode with exp bias = 8
static constexpr uint8_t binary_min = 0x08; // 0b00001000
static constexpr uint8_t binary_max = 0x77; // 0b01110111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericLimits<bf8_t>
{
// negative zero nan mode with exp bias = 16
static constexpr uint8_t binary_min = 0x04; // 0b00000100
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
__host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
__host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
__host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
};
#endif
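For intuition: in this negative-zero-NaN (FNUZ) mode, f8_t is E4M3 with exponent bias 8 and bf8_t is E5M2 with bias 16, so binary_max decodes to (1 + 7/8) * 2^(15-8) = 240 for f8_t and (1 + 3/4) * 2^(31-16) = 57344 for bf8_t. A minimal host-side decoder illustrating the arithmetic (a sketch under these assumptions, not CK's API):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode a *normal* FNUZ float8 value; zero (0x00), NaN (0x80) and
// subnormals are deliberately ignored in this sketch.
double decode_fnuz(uint8_t bits, int exp_bits, int mant_bits, int bias)
{
    const int sign     = bits >> (exp_bits + mant_bits);
    const int exponent = (bits >> mant_bits) & ((1 << exp_bits) - 1);
    const int mantissa = bits & ((1 << mant_bits) - 1);
    const double magnitude =
        std::ldexp(1.0 + mantissa / double(1 << mant_bits), exponent - bias);
    return sign ? -magnitude : magnitude;
}

int main()
{
    std::printf("f8 max  = %g\n", decode_fnuz(0x7F, 4, 3, 8));  // 240
    std::printf("f8 min  = %g\n", decode_fnuz(0x08, 4, 3, 8));  // 2^-7
    std::printf("bf8 max = %g\n", decode_fnuz(0x7F, 5, 2, 16)); // 57344
    std::printf("bf8 min = %g\n", decode_fnuz(0x04, 5, 2, 16)); // 2^-15
    return 0;
}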
template <typename T>
struct NumericUtils
{
};
template <>
struct NumericUtils<float>
{
static constexpr int exp = 8;
static constexpr int mant = 23;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
using bitwise_type = uint32_t;
};
template <>
struct NumericUtils<half_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
#if defined CK_ENABLE_FP8
template <>
struct NumericUtils<f8_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericUtils<bf8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
};
#endif
} // namespace ck
......@@ -5,6 +5,7 @@
#include "ck/utility/data_type.hpp"
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace ck {
// fp8 rounding modes
......@@ -22,53 +23,38 @@ namespace ck::utils {
namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{
// check data type
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
// fp8/bf8 exponent/mantissa layout
constexpr int out_exp = NumericUtils<Y>::exp;
constexpr int out_mant = NumericUtils<Y>::mant;
// fp8 exponent/mantissa layout
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
// original type exponent/mantissa layout
constexpr int in_exp = NumericUtils<X>::exp;
constexpr int in_mant = NumericUtils<X>::mant;
int exponent;
uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr uint8_t nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
constexpr Y nan_code = 0x80;
constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;
// convert to bitwise
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type
T_bitwise;
using T_bitwise = typename NumericUtils<X>::bitwise_type;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype
if constexpr(is_float)
{
head = x_bitwise & 0xFF800000;
mantissa = x_bitwise & 0x7FFFFF;
exponent = (head >> type_mant) & 0xFF;
sign = head >> (type_exp + type_mant);
}
else if constexpr(is_half)
{
head = x_bitwise & 0xFC00;
mantissa = x_bitwise & 0x3FF;
exponent = (head >> type_mant) & 0x1F;
sign = head >> (type_exp + type_mant);
}
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
head = x_bitwise & NumericUtils<X>::head_mask;
mantissa = x_bitwise & NumericUtils<X>::mant_mask;
exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
sign = head >> (in_exp + in_mant);
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
(1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan)
{
......@@ -81,22 +67,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
return signed_inf + (mantissa != 0 ? 1 : 0);
}
// if input is half and output is bf8
if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
exponent == 0)
{
exponent += 1;
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent -= 1;
}
mantissa &= ~(1 << in_mant);
}
// check if x is 0.0
if(x_bitwise == 0)
return 0;
exponent -= exp_low_cutoff - 1;
if(exponent <= 0)
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1;
mantissa += 1 << type_mant;
drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
mantissa += 1 << in_mant;
// apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << type_mant))
if(mantissa >= (2 << in_mant))
{
mantissa >>= 1;
exponent++;
}
mantissa >>= (type_mant - f8_mant);
mantissa >>= (in_mant - out_mant);
// check negative exponent
if(exponent <= 0)
......@@ -116,7 +115,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{
if(clip)
{
mantissa = (1 << f8_mant) - 1;
mantissa = (1 << out_mant) - 1;
exponent = max_exp;
}
else
......@@ -127,124 +126,120 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
// check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant));
mantissa &= (1 << f8_mant) - 1;
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
mantissa &= (1 << out_mant) - 1;
return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
}
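As a worked check of the bias arithmetic above: for a float (in_exp = 8) to f8_t (out_exp = 4) cast in negative-zero-NaN mode, exp_low_cutoff = 2^7 - 2^3 + 1 - 1 = 120, so after the later exponent -= exp_low_cutoff - 1 step an IEEE exponent field of 120 (the value 2^-7) lands on fp8 exponent 1, which is exactly NumericLimits<f8_t>::binary_min = 0x08 from the table above.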
template <typename T, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x)
template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ Y run_cast_from_f8(X x)
{
// check data type
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
// fp8 exponent/mantissa layout
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
// fp8/bf8 exponent/mantissa layout
constexpr int in_exp = NumericUtils<X>::exp;
constexpr int in_mant = NumericUtils<X>::mant;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
constexpr int out_exp = NumericUtils<Y>::exp;
constexpr int out_mant = NumericUtils<Y>::mant;
// prepare the codes
constexpr uint8_t nan_code = 0x80;
T fInf, fNegInf, fNaN, fNeg0;
if constexpr(is_half)
{
constexpr uint16_t ihInf = 0x7C00;
constexpr uint16_t ihNegInf = 0xFC00;
constexpr uint16_t ihNaN = 0x7C01;
constexpr uint16_t ihNeg0 = 0x8000;
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0));
}
else if constexpr(is_float)
{
constexpr uint32_t ifInf = 0x7F800000;
constexpr uint32_t ifNegInf = 0xFF800000;
constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
}
constexpr X nan_code = 0x80;
Y Inf, NegInf, NaN, Neg0;
using T_bitwise = typename NumericUtils<Y>::bitwise_type;
constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
constexpr T_bitwise NaN_bitwise = NumericUtils<Y>::NaN;
constexpr T_bitwise Neg0_bitwise = NumericUtils<Y>::Neg0;
Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
// check if x is 0.0
if(x == 0)
return static_cast<Y>(0);
// unpack the input
uint32_t sign = x >> (f8_exp + f8_mant);
uint32_t mantissa = x & ((1 << f8_mant) - 1);
int exponent = (x & 0x7F) >> f8_mant;
uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> in_mant;
constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval;
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
T_bitwise retval;
if constexpr(negative_zero_nan)
{
if(x == nan_code)
return fNaN;
return NaN;
}
else
{
if(x == nan_code)
return fNeg0;
if(exponent == ((1 << f8_exp) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
return Neg0;
if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
}
if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
{
retval = x;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
}
// subnormal input
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant);
mantissa <<= sh;
mantissa &= ((1 << f8_mant) - 1);
exponent += 1 - sh;
exponent++;
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent--;
}
mantissa &= ((1 << in_mant) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= type_mant - f8_mant;
mantissa <<= out_mant - in_mant;
// subnormal output (occurs when Y = half_t, out_exp = 5, negative_zero_nan = true)
if(exponent <= 0)
{
mantissa |= 1 << type_mant;
mantissa |= 1 << out_mant;
mantissa >>= 1 - exponent;
exponent = 0;
}
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
return *(reinterpret_cast<const T*>(&retval));
retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return *(reinterpret_cast<const Y*>(&retval));
}
} // namespace
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ Y cast_to_f8(X x, uint32_t rng)
{
// check datatype
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted to f8.");
// check datatypes
constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted.");
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
}
template <typename T, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x)
template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ Y cast_from_f8(X x)
{
// check datatype
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported.");
// check if x is 0.0
if(x == 0)
return static_cast<T>(0);
return run_cast_from_f8<T, negative_zero_nan>(x);
return run_cast_from_f8<X, Y, negative_zero_nan>(x);
}
} // namespace ck::utils
#endif
......@@ -80,6 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
#if defined CK_ENABLE_FP8
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
......@@ -88,8 +89,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
}
// convert fp8 to fp32
......@@ -97,7 +99,7 @@ template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<float, negative_zero_nan>(x);
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
}
// convert fp16 to fp8
......@@ -108,8 +110,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
// convert fp8 to fp16
......@@ -117,8 +120,53 @@ template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<half_t, negative_zero_nan>(x);
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
// convert bf8 to fp32
template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
{
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
}
// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
// convert bf8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
{
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
}
#endif
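With these specializations, bf8 values round-trip through the ordinary type_convert entry points. A host-side sketch (header path is an assumption; E5M2 keeps only 2 mantissa bits, so expect coarse rounding):

#include <cstdio>
#include "ck/utility/type_convert.hpp" // path assumed

int main()
{
    float x = 3.14159f;
    ck::bf8_t q = ck::type_convert<ck::bf8_t, float>(x);
    float y = ck::type_convert<float, ck::bf8_t>(q);
    // y is x rounded to E5M2FNUZ precision (about 2 mantissa bits), e.g. 3.0
    std::printf("%f -> %f\n", x, y);
    return 0;
}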
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
......@@ -181,6 +229,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
......@@ -191,8 +240,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
}
// convert fp16 to fp8 with stochastic rounding
......@@ -205,8 +255,42 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
#endif
} // namespace ck
......@@ -20,7 +20,8 @@ template <typename ADataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
typename ComputeDataType = ADataType>
struct ReferenceGemm : public device::BaseOperator
{
// Argument
......@@ -64,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k)
{
ADataType v_a;
BDataType v_b;
ComputeDataType v_a;
ComputeDataType v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
......