Commit dcf15456 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Merge branch 'develop' into lwpck-911

parents 1e3eb1c6 5fe687fa
# Documentation files
docs/* @saadrahim @LisaDelaney
*.md @saadrahim @LisaDelaney
*.rst @saadrahim @LisaDelaney
# Header directory
library/include/* @saadrahim @LisaDelaney
cmake_minimum_required(VERSION 3.14) cmake_minimum_required(VERSION 3.14)
# This has to be initialized before the project() command appears
# Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE
if( NOT MSVC_IDE AND NOT CMAKE_BUILD_TYPE )
set( CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel." )
endif()
# Default installation path
if(WIN32)
set(CMAKE_INSTALL_PREFIX "/opt/rocm/x86_64-w64-mingw32" CACHE PATH "")
else()
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
endif()
set(version 1.1.0) set(version 1.1.0)
# Check support for CUDA/HIP in Cmake # Check support for CUDA/HIP in Cmake
project(composable_kernel VERSION ${version}) project(composable_kernel VERSION ${version})
...@@ -15,6 +28,12 @@ if (DTYPES) ...@@ -15,6 +28,12 @@ if (DTYPES)
if (DTYPES MATCHES "fp8") if (DTYPES MATCHES "fp8")
add_definitions(-DCK_ENABLE_FP8) add_definitions(-DCK_ENABLE_FP8)
set(CK_ENABLE_FP8 "ON") set(CK_ENABLE_FP8 "ON")
add_compile_options(-Wno-bit-int-extension)
endif()
if (DTYPES MATCHES "bf8")
add_definitions(-DCK_ENABLE_BF8)
set(CK_ENABLE_BF8 "ON")
add_compile_options(-Wno-bit-int-extension)
endif() endif()
if (DTYPES MATCHES "fp16") if (DTYPES MATCHES "fp16")
add_definitions(-DCK_ENABLE_FP16) add_definitions(-DCK_ENABLE_FP16)
...@@ -34,8 +53,9 @@ if (DTYPES) ...@@ -34,8 +53,9 @@ if (DTYPES)
endif() endif()
message("DTYPES macro set to ${DTYPES}") message("DTYPES macro set to ${DTYPES}")
else() else()
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
set(CK_ENABLE_ALL_DTYPES "ON") set(CK_ENABLE_ALL_DTYPES "ON")
add_compile_options(-Wno-bit-int-extension) # enable fp8 and bf8
endif() endif()
if(DL_KERNELS) if(DL_KERNELS)
...@@ -365,6 +385,10 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ...@@ -365,6 +385,10 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
#message("fp8 instance found!") #message("fp8 instance found!")
set(add_inst 1) set(add_inst 1)
endif() endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf8\" " AND DTYPES MATCHES "bf8")
#message("bf8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16") if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
#message("fp16 instance found!") #message("fp16 instance found!")
set(add_inst 1) set(add_inst 1)
......
...@@ -210,6 +210,9 @@ def cmake_build(Map conf=[:]){ ...@@ -210,6 +210,9 @@ def cmake_build(Map conf=[:]){
} else{ } else{
setup_args = ' -DBUILD_DEV=On' + setup_args setup_args = ' -DBUILD_DEV=On' + setup_args
} }
if (params.DL_KERNELS){
setup_args = setup_args + " -DDL_KERNELS=ON "
}
if(build_type_debug){ if(build_type_debug){
setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args
...@@ -367,8 +370,6 @@ def runCKProfiler(Map conf=[:]){ ...@@ -367,8 +370,6 @@ def runCKProfiler(Map conf=[:]){
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 24, unit: 'HOURS') timeout(time: 24, unit: 'HOURS')
{ {
//cmake_build(conf)
//instead of building, just unstash the ckProfiler and install it
sh """ sh """
rm -rf build rm -rf build
mkdir build mkdir build
...@@ -614,7 +615,7 @@ def process_results(Map conf=[:]){ ...@@ -614,7 +615,7 @@ def process_results(Map conf=[:]){
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1 CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1
0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT= 0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT=
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""
pipeline { pipeline {
agent none agent none
...@@ -649,6 +650,10 @@ pipeline { ...@@ -649,6 +650,10 @@ pipeline {
name: "RUN_FULL_QA", name: "RUN_FULL_QA",
defaultValue: false, defaultValue: false,
description: "Select whether to run small set of performance tests (default) or full QA") description: "Select whether to run small set of performance tests (default) or full QA")
booleanParam(
name: "DL_KERNELS",
defaultValue: false,
description: "Select whether to build DL kernels (default: OFF)")
} }
environment{ environment{
dbuser = "${dbuser}" dbuser = "${dbuser}"
...@@ -663,15 +668,12 @@ pipeline { ...@@ -663,15 +668,12 @@ pipeline {
} }
stages{ stages{
stage("Build Docker"){ stage("Build Docker"){
//when {
// beforeAgent true
// expression { params.BUILD_DOCKER.toBoolean() }
//}
parallel{ parallel{
stage('Docker /opt/rocm'){ stage('Docker /opt/rocm'){
agent{ label rocmnode("nogpu") } agent{ label rocmnode("nogpu") }
steps{ steps{
buildDocker('/opt/rocm') buildDocker('/opt/rocm')
cleanWs()
} }
} }
} }
...@@ -693,6 +695,7 @@ pipeline { ...@@ -693,6 +695,7 @@ pipeline {
} }
steps{ steps{
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
cleanWs()
} }
} }
} }
...@@ -715,6 +718,7 @@ pipeline { ...@@ -715,6 +718,7 @@ pipeline {
} }
steps{ steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
} }
} }
stage("Build CK and run Tests on MI100/MI200") stage("Build CK and run Tests on MI100/MI200")
...@@ -730,6 +734,7 @@ pipeline { ...@@ -730,6 +734,7 @@ pipeline {
} }
steps{ steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
} }
} }
stage("Build CK and run Tests on Navi21") stage("Build CK and run Tests on Navi21")
...@@ -742,10 +747,10 @@ pipeline { ...@@ -742,10 +747,10 @@ pipeline {
environment{ environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
} }
steps{ steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
} }
} }
stage("Build CK and run Tests on Navi32") stage("Build CK and run Tests on Navi32")
...@@ -756,12 +761,12 @@ pipeline { ...@@ -756,12 +761,12 @@ pipeline {
} }
agent{ label rocmnode("navi32") } agent{ label rocmnode("navi32") }
environment{ environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
} }
steps{ steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
} }
} }
} }
...@@ -784,6 +789,7 @@ pipeline { ...@@ -784,6 +789,7 @@ pipeline {
} }
steps{ steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
cleanWs()
} }
} }
stage("Run ckProfiler: gfx90a") stage("Run ckProfiler: gfx90a")
...@@ -799,6 +805,7 @@ pipeline { ...@@ -799,6 +805,7 @@ pipeline {
} }
steps{ steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
cleanWs()
} }
} }
} }
...@@ -811,6 +818,7 @@ pipeline { ...@@ -811,6 +818,7 @@ pipeline {
agent { label 'mici' } agent { label 'mici' }
steps{ steps{
process_results() process_results()
cleanWs()
} }
} }
} }
......
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations) add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations)
endif()
...@@ -69,5 +69,7 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) ...@@ -69,5 +69,7 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
endif() endif()
endif() endif()
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
endif()
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
using InDataType = F16; using InDataType = F16;
using WeiDataType = F16; using WeiDataType = F16;
...@@ -15,44 +15,55 @@ using WeiElementOp = PassThrough; ...@@ -15,44 +15,55 @@ using WeiElementOp = PassThrough;
using OutElementOp = PassThrough; using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl<
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl< NDimSpatial, // NDimSpatial
NDimSpatial, // NDimSpatial ck::tuple_element_t<NDimSpatial - 1,
InDataType, // InDataType ck::Tuple<ck::tensor_layout::convolution::GNWC,
WeiDataType, // WeiDataType ck::tensor_layout::convolution::GNHWC,
OutDataType, // OutDataType ck::tensor_layout::convolution::GNDHWC>>, // InLayout
AccDataType, // AccDataType ck::tuple_element_t<NDimSpatial - 1,
InElementOp, // InElementwiseOperation ck::Tuple<ck::tensor_layout::convolution::GKXC,
WeiElementOp, // WeiElementwiseOperation ck::tensor_layout::convolution::GKYXC,
OutElementOp, // OutElementwiseOperation ck::tensor_layout::convolution::GKZYXC>>, // WeiLayout
ConvBwdWeightDefault, // ConvBackwardWeightSpecialization ck::tuple_element_t<NDimSpatial - 1,
256, // BlockSize ck::Tuple<ck::tensor_layout::convolution::GNWK,
128, // MPerBlock ck::tensor_layout::convolution::GNHWK,
128, // NPerBlock ck::tensor_layout::convolution::GNDHWK>>, // OutLayout
16, // K0PerBlock InDataType, // InDataType
2, // K1 WeiDataType, // WeiDataType
4, // M1PerThread OutDataType, // OutDataType
4, // N1PerThread AccDataType, // AccDataType
1, // KPerThread InElementOp, // InElementwiseOperation
S<8, 2>, // M1N1ThreadClusterM1Xs WeiElementOp, // WeiElementwiseOperation
S<8, 2>, // M1N1ThreadClusterN1Xs OutElementOp, // OutElementwiseOperation
S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 ConvBwdWeightDefault, // ConvBackwardWeightSpecialization
S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 256, // BlockSize
S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder 128, // MPerBlock
S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder 128, // NPerBlock
S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 16, // K0PerBlock
S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder 2, // K1
S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 4, // M1PerThread
S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 4, // N1PerThread
S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 1, // KPerThread
S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder S<8, 2>, // M1N1ThreadClusterM1Xs
S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder S<8, 2>, // M1N1ThreadClusterN1Xs
S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder
S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder
5, // CThreadTransferSrcDstVectorDim S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
4>; // CThreadTransferDstScalarPerVector S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder
S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
4>; // CThreadTransferDstScalarPerVector
#include "run_grouped_conv_bwd_weight_example.inc" #include "run_grouped_conv_bwd_weight_example.inc"
......
...@@ -14,20 +14,8 @@ template <ck::index_t NDimSpatial> ...@@ -14,20 +14,8 @@ template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_param) const ck::utils::conv::ConvParam& conv_param)
{ {
ck::index_t split_k;
// Set split_k = 2 for xdl op, split_k = 1 for dl
// Dl op doesn't support split_k > 1 // Dl op doesn't support split_k > 1
// TODO: Add Dl op split_k > 1 support constexpr ck::index_t split_k = 1;
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102"))
{
split_k = 2;
}
else
{
split_k = 1;
}
const auto in_g_n_c_wis_desc = const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
......
...@@ -14,18 +14,22 @@ using ComputeDataType = float; ...@@ -14,18 +14,22 @@ using ComputeDataType = float;
struct YElementOp struct YElementOp
{ {
template <typename T> template <typename Y, typename X>
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value || static_assert(ck::is_same<X, float>::value || ck::is_same<X, double>::value ||
ck::is_same<T, ck::half_t>::value, ck::is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T a; static_assert(ck::is_same<Y, float>::value || ck::is_same<Y, double>::value ||
ck::is_same<Y, ck::half_t>::value,
"Data type is not supported by this operation!");
X a;
ck::tensor_operation::element_wise::Sigmoid{}(a, x); ck::tensor_operation::element_wise::Sigmoid{}(a, x);
y = x * a; y = ck::type_convert<Y>(x * a);
}; };
}; };
......
...@@ -43,6 +43,9 @@ ...@@ -43,6 +43,9 @@
#ifndef CK_ENABLE_FP8 #ifndef CK_ENABLE_FP8
#define CK_ENABLE_FP8 "ON" #define CK_ENABLE_FP8 "ON"
#endif #endif
#ifndef CK_ENABLE_BF8
#define CK_ENABLE_BF8 "ON"
#endif
#ifndef CK_ENABLE_FP16 #ifndef CK_ENABLE_FP16
#define CK_ENABLE_FP16 "ON" #define CK_ENABLE_FP16 "ON"
#endif #endif
...@@ -66,6 +69,10 @@ ...@@ -66,6 +69,10 @@
#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@ #cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@
#endif #endif
#ifndef CK_ENABLE_BF8
#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@
#endif
#ifndef CK_ENABLE_FP16 #ifndef CK_ENABLE_FP16
#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@ #cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@
#endif #endif
......
...@@ -144,7 +144,8 @@ template <typename ALayout, ...@@ -144,7 +144,8 @@ template <typename ALayout,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeDataType = EDataType>
struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
...@@ -243,11 +244,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -243,11 +244,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1)); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = EDataType;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType,
BDataType, BDataType,
ComputeDataType, ComputeDataType,
AccDataType, AccDataType,
......
...@@ -27,6 +27,12 @@ struct PassThrough ...@@ -27,6 +27,12 @@ struct PassThrough
y = x; y = x;
} }
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
{
y = type_convert<float>(x);
}
template <> template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const __host__ __device__ void operator()<float, float>(float& y, const float& x) const
{ {
...@@ -81,6 +87,12 @@ struct PassThrough ...@@ -81,6 +87,12 @@ struct PassThrough
y = type_convert<int8_t>(x); y = type_convert<int8_t>(x);
} }
template <>
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <> template <>
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const __host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
...@@ -89,6 +101,7 @@ struct PassThrough ...@@ -89,6 +101,7 @@ struct PassThrough
} }
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{ {
...@@ -118,6 +131,7 @@ struct PassThrough ...@@ -118,6 +131,7 @@ struct PassThrough
{ {
y = type_convert<f8_t>(x); y = type_convert<f8_t>(x);
} }
#endif
}; };
struct UnaryConvert struct UnaryConvert
...@@ -146,6 +160,7 @@ struct ConvertBF16RTN ...@@ -146,6 +160,7 @@ struct ConvertBF16RTN
} }
}; };
#if defined CK_ENABLE_FP8
struct ConvertF8SR struct ConvertF8SR
{ {
// convert to fp8 using stochastic rounding (SR) // convert to fp8 using stochastic rounding (SR)
...@@ -162,6 +177,7 @@ struct ConvertF8SR ...@@ -162,6 +177,7 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x); y = f8_convert_sr<Y>(x);
} }
}; };
#endif
struct Scale struct Scale
{ {
...@@ -412,14 +428,19 @@ struct Swish ...@@ -412,14 +428,19 @@ struct Swish
{ {
Swish(float beta = 1.0f) : beta_(beta) {} Swish(float beta = 1.0f) : beta_(beta) {}
template <typename T> template <typename Y, typename X>
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<T, ck::half_t>::value, is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!");
static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
is_same<Y, ck::half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x)); float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
}; };
float beta_ = 1.0f; float beta_ = 1.0f;
......
...@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v; DstData v;
// apply element-wise operation // apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]); element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert dst_vector.template AsType<DstData>()(i) = v;
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
}); });
const bool is_dst_valid = const bool is_dst_valid =
...@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ...@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr index_t dst_offset = dst_desc.CalculateOffset( constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v; DstData v;
// apply element-wise operation // apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]); element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert // apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v); dst_buf(Number<dst_offset>{}) = v;
}); });
}); });
} }
......
...@@ -11,8 +11,14 @@ namespace ck { ...@@ -11,8 +11,14 @@ namespace ck {
enum struct DppInstr enum struct DppInstr
{ {
dpp8_f16_16x16x2 = 0, dpp8_f16_1x32x2 = 0,
dpp8_f16_2x16x2,
dpp8_f16_2x32x2,
dpp8_f16_4x16x2,
dpp8_f16_4x32x2,
dpp8_f16_8x16x2,
dpp8_f16_8x32x2, dpp8_f16_8x32x2,
dpp8_f16_16x16x2,
dpp8_f16_32x8x2 dpp8_f16_32x8x2
}; };
...@@ -101,6 +107,36 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2> ...@@ -101,6 +107,36 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2>
} }
}; };
template <>
struct dpp_type<DppInstr::dpp8_f16_8x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 8;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 4;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 4;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <> template <>
struct dpp_type<DppInstr::dpp8_f16_16x16x2> struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{ {
...@@ -131,6 +167,156 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2> ...@@ -131,6 +167,156 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2>
} }
}; };
template <>
struct dpp_type<DppInstr::dpp8_f16_4x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 4;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 4;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 4;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_4x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 4;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 2;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 2;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_1x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 1;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 1;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 1;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_2x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 2;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 2;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 2;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_2x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 2;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 1;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 1;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <typename BaseType, index_t MPerDpp, index_t NPerDpp> template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
struct DppSelector struct DppSelector
{ {
...@@ -143,6 +329,12 @@ struct DppSelector ...@@ -143,6 +329,12 @@ struct DppSelector
return DppInstr::dpp8_f16_8x32x2; return DppInstr::dpp8_f16_8x32x2;
} }
template <>
static constexpr auto GetDpp<half_t, 8, 16>()
{
return DppInstr::dpp8_f16_8x16x2;
}
template <> template <>
static constexpr auto GetDpp<half_t, 16, 16>() static constexpr auto GetDpp<half_t, 16, 16>()
{ {
...@@ -155,6 +347,36 @@ struct DppSelector ...@@ -155,6 +347,36 @@ struct DppSelector
return DppInstr::dpp8_f16_32x8x2; return DppInstr::dpp8_f16_32x8x2;
} }
template <>
static constexpr auto GetDpp<half_t, 1, 32>()
{
return DppInstr::dpp8_f16_1x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 32>()
{
return DppInstr::dpp8_f16_2x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 16>()
{
return DppInstr::dpp8_f16_2x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 16>()
{
return DppInstr::dpp8_f16_4x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 32>()
{
return DppInstr::dpp8_f16_4x32x2;
}
static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{}; static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};
__host__ __device__ constexpr DppSelector() __host__ __device__ constexpr DppSelector()
...@@ -191,7 +413,6 @@ struct DppSelector ...@@ -191,7 +413,6 @@ struct DppSelector
// in the future when the implementation is more generalized. // in the future when the implementation is more generalized.
static_assert(selected_dpp.share_a); static_assert(selected_dpp.share_a);
static_assert(selected_dpp.n_per_thread == 1); static_assert(selected_dpp.n_per_thread == 1);
static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size);
static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
static_assert(selected_dpp.n_per_lanegroup == static_assert(selected_dpp.n_per_lanegroup ==
selected_dpp.n_per_thread * selected_dpp.lanegroup_size); selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
...@@ -215,11 +436,6 @@ struct DppGemm ...@@ -215,11 +436,6 @@ struct DppGemm
__host__ __device__ constexpr DppGemm() __host__ __device__ constexpr DppGemm()
{ {
static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32,
"MPerDpp must be either 8, 16 or 32.");
static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32,
"NPerDpp must be either 8, 16 or 32.");
static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp."); static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
} }
......
...@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64> ...@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
} }
}; };
#if defined CK_ENABLE_FP8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{ {
...@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8> ...@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops> template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector struct MfmaSelector
...@@ -640,6 +642,7 @@ struct MfmaSelector ...@@ -640,6 +642,7 @@ struct MfmaSelector
} }
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() static constexpr auto GetMfma<f8_t, 32, 32>()
{ {
...@@ -651,6 +654,7 @@ struct MfmaSelector ...@@ -651,6 +654,7 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32f8f8; return MfmaInstr::mfma_f32_16x16x32f8f8;
} }
#endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{}; static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
...@@ -852,7 +856,11 @@ struct XdlopsGemm ...@@ -852,7 +856,11 @@ struct XdlopsGemm
{ {
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value || static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value || is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value, is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8
|| is_same<base_type, f8_t>::value
#endif
,
"base base_type must be double, float, half, bfloat16, and int8_t!"); "base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
...@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
} }
else else
{ {
#endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>( return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
} }
#endif
#else #else
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
} }
else else
{ {
#endif
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t(0);
#if defined CK_ENABLE_FP8
} }
#endif #endif
#endif
} }
// buffer_load requires: // buffer_load requires:
...@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = auto tmp =
...@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
} }
else else
{ {
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
} }
#endif
#else #else
if(dst_thread_element_valid) if(dst_thread_element_valid)
{ {
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>( auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
...@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
} }
else else
{ {
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
} }
#endif
} }
#endif #endif
} }
......
...@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> ...@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
} }
}; };
#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8; struct intrin_mfma_f32_32x32x16f8f8;
...@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> ...@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif #endif
} }
}; };
#endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -12,7 +12,12 @@ using half_t = _Float16; ...@@ -12,7 +12,12 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
#endif #endif
using f8_t = uint8_t; #if defined CK_ENABLE_FP8
using f8_t = _BitInt(8);
#endif
#if defined CK_ENABLE_BF8
using bf8_t = unsigned _BitInt(8);
#endif
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
...@@ -143,14 +148,24 @@ struct scalar_type<int4_t> ...@@ -143,14 +148,24 @@ struct scalar_type<int4_t>
}; };
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
struct scalar_type<f8_t> struct scalar_type<f8_t>
{ {
using type = f8_t; using type = f8_t;
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
#endif
#if defined CK_ENABLE_BF8
template <>
struct scalar_type<bf8_t>
{
using type = bf8_t;
static constexpr index_t vector_size = 1;
};
#endif
//
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
{ {
...@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type; ...@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type; using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8 // f8
#if defined CK_ENABLE_FP8
using f8x2_t = typename vector_type<f8_t, 2>::type; using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type; using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type; using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type; using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type; using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type; using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif
// bf8
#if defined CK_ENABLE_BF8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
...@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t> ...@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t>
}; };
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template <> template <>
struct NumericLimits<f8_t> struct NumericLimits<f8_t>
{ {
// negative zero nan mode with exp bias = 8
static constexpr uint8_t binary_min = 0x08; // 0b00001000 static constexpr uint8_t binary_min = 0x08; // 0b00001000
static constexpr uint8_t binary_max = 0x77; // 0b01110111 static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericLimits<bf8_t>
{
// negative zero nan mode with exp bias = 16
static constexpr uint8_t binary_min = 0x04; // 0b00000100
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); } __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); } __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); } __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); } __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
}; };
#endif
template <typename T>
struct NumericUtils
{
};
template <>
struct NumericUtils<float>
{
static constexpr int exp = 8;
static constexpr int mant = 23;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
using bitwise_type = uint32_t;
};
template <>
struct NumericUtils<half_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
#if defined CK_ENABLE_FP8
template <>
struct NumericUtils<f8_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericUtils<bf8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
};
#endif
} // namespace ck } // namespace ck
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
// these conversions are disabled if native conversions available // these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__) #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace ck { namespace ck {
// fp8 rounding modes // fp8 rounding modes
...@@ -24,53 +25,38 @@ namespace ck::utils { ...@@ -24,53 +25,38 @@ namespace ck::utils {
namespace { namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{ {
// check data type // fp8/bf8 exponent/mantissa layout
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr int out_exp = NumericUtils<Y>::exp;
constexpr bool is_float = std::is_same<T, float>::value; constexpr int out_mant = NumericUtils<Y>::mant;
// fp8 exponent/mantissa layout // original type exponent/mantissa layout
constexpr int f8_exp = 4; constexpr int in_exp = NumericUtils<X>::exp;
constexpr int f8_mant = 3; constexpr int in_mant = NumericUtils<X>::mant;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
int exponent; int exponent;
uint32_t head, mantissa, sign; uint32_t head, mantissa, sign;
// nan code is same for float and half // nan code is same for float and half
constexpr uint8_t nan_code = 0x80; constexpr Y nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;
// convert to bitwise // convert to bitwise
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type using T_bitwise = typename NumericUtils<X>::bitwise_type;
T_bitwise;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x)); T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype // unpack the input, depends on datatype
if constexpr(is_float) head = x_bitwise & NumericUtils<X>::head_mask;
{ mantissa = x_bitwise & NumericUtils<X>::mant_mask;
head = x_bitwise & 0xFF800000; exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
mantissa = x_bitwise & 0x7FFFFF; sign = head >> (in_exp + in_mant);
exponent = (head >> type_mant) & 0xFF;
sign = head >> (type_exp + type_mant); uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
} uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
else if constexpr(is_half) constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
{
head = x_bitwise & 0xFC00;
mantissa = x_bitwise & 0x3FF;
exponent = (head >> type_mant) & 0x1F;
sign = head >> (type_exp + type_mant);
}
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff = constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
...@@ -83,22 +69,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -83,22 +69,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
return signed_inf + (mantissa != 0 ? 1 : 0); return signed_inf + (mantissa != 0 ? 1 : 0);
} }
// if input is half and output is bf8
if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
exponent == 0)
{
exponent += 1;
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent -= 1;
}
mantissa &= ~(1 << in_mant);
}
// check if x is 0.0 // check if x is 0.0
if(x_bitwise == 0) if(x_bitwise == 0)
return 0; return 0;
exponent -= exp_low_cutoff - 1; exponent -= exp_low_cutoff - 1;
if(exponent <= 0) if(exponent <= 0)
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
mantissa += 1 << type_mant; mantissa += 1 << in_mant;
// apply random number if needed // apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask; mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << type_mant)) if(mantissa >= (2 << in_mant))
{ {
mantissa >>= 1; mantissa >>= 1;
exponent++; exponent++;
} }
mantissa >>= (type_mant - f8_mant); mantissa >>= (in_mant - out_mant);
// check negative exponent // check negative exponent
if(exponent <= 0) if(exponent <= 0)
...@@ -118,7 +117,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -118,7 +117,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{ {
if(clip) if(clip)
{ {
mantissa = (1 << f8_mant) - 1; mantissa = (1 << out_mant) - 1;
exponent = max_exp; exponent = max_exp;
} }
else else
...@@ -129,125 +128,121 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -129,125 +128,121 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
// check if x is 0.0 or -0.0 // check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0) if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
mantissa &= (1 << f8_mant) - 1; mantissa &= (1 << out_mant) - 1;
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
} }
template <typename T, bool negative_zero_nan> template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x) __host__ __device__ Y run_cast_from_f8(X x)
{ {
// check data type // fp8/bf8 exponent/mantissa layout
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr int in_exp = NumericUtils<X>::exp;
constexpr bool is_float = std::is_same<T, float>::value; constexpr int in_mant = NumericUtils<X>::mant;
// fp8 exponent/mantissa layout
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
// resulting type exponent/mantissa layout // resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8; constexpr int out_exp = NumericUtils<Y>::exp;
constexpr int type_mant = is_half ? 10 : 23; constexpr int out_mant = NumericUtils<Y>::mant;
// prepare the codes // prepare the codes
constexpr uint8_t nan_code = 0x80; constexpr X nan_code = 0x80;
T fInf, fNegInf, fNaN, fNeg0; Y Inf, NegInf, NaN, Neg0;
if constexpr(is_half) using T_bitwise = typename NumericUtils<Y>::bitwise_type;
{
constexpr uint16_t ihInf = 0x7C00; constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
constexpr uint16_t ihNegInf = 0xFC00; constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
constexpr uint16_t ihNaN = 0x7C01; constexpr T_bitwise NaN_bitwise = NumericUtils<Y>::NaN;
constexpr uint16_t ihNeg0 = 0x8000; constexpr T_bitwise Neg0_bitwise = NumericUtils<Y>::Neg0;
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf)); Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN)); NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0)); NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
} Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
else if constexpr(is_float)
{ // check if x is 0.0
constexpr uint32_t ifInf = 0x7F800000; if(x == 0)
constexpr uint32_t ifNegInf = 0xFF800000; return static_cast<Y>(0);
constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
}
// unpack the input // unpack the input
uint32_t sign = x >> (f8_exp + f8_mant); uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << f8_mant) - 1); uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> f8_mant; int exponent = (x & 0x7F) >> in_mant;
constexpr int exp_low_cutoff = constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval; T_bitwise retval;
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
if(x == nan_code) if(x == nan_code)
return fNaN; return NaN;
} }
else else
{ {
if(x == nan_code) if(x == nan_code)
return fNeg0; return Neg0;
if(exponent == ((1 << f8_exp) - 1)) if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
}
if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
{
retval = x;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
} }
// subnormal input // subnormal input
if(exponent == 0) if(exponent == 0)
{ {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); exponent++;
mantissa <<= sh; while(mantissa < (1 << in_mant))
mantissa &= ((1 << f8_mant) - 1); {
exponent += 1 - sh; mantissa <<= 1;
exponent--;
}
mantissa &= ((1 << in_mant) - 1);
} }
exponent += exp_low_cutoff - 1; exponent += exp_low_cutoff - 1;
mantissa <<= type_mant - f8_mant; mantissa <<= out_mant - in_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true) // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0) if(exponent <= 0)
{ {
mantissa |= 1 << type_mant; mantissa |= 1 << out_mant;
mantissa >>= 1 - exponent; mantissa >>= 1 - exponent;
exponent = 0; exponent = 0;
} }
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return *(reinterpret_cast<const T*>(&retval)); return *(reinterpret_cast<const Y*>(&retval));
} }
} // namespace } // namespace
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) __host__ __device__ Y cast_to_f8(X x, uint32_t rng)
{ {
// check datatype // check datatypes
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted to f8."); static_assert(is_half || is_float, "Only half and float can be casted.");
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng); return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
} }
template <typename T, bool negative_zero_nan> template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x) __host__ __device__ Y cast_from_f8(X x)
{ {
// check datatype // check datatype
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported."); static_assert(is_half || is_float, "only half and float are supported.");
// check if x is 0.0 return run_cast_from_f8<X, Y, negative_zero_nan>(x);
if(x == 0)
return static_cast<T>(0);
return run_cast_from_f8<T, negative_zero_nan>(x);
} }
} // namespace ck::utils } // namespace ck::utils
#endif #endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
...@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f ...@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
c); c);
} }
template <>
__device__ void inner_product<bhalf_t, bhalf_t, float>(const bhalf_t& a, const bhalf_t& b, float& c)
{
inner_product(type_convert<float>(a), type_convert<float>(b), c);
}
template <>
__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
{
inner_product(type_convert<float>(a), type_convert<float>(b), c);
}
template <> template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c) __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment