Commit de1afb7b authored by Rostyslav Geyyer

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into lwpck-977

parents ce562aa6 f7331c60
@@ -2,44 +2,44 @@ function(add_instance_library INSTANCE_NAME)
 message("adding instance ${INSTANCE_NAME}")
 set(result 1)
 if(DEFINED DTYPES)
 foreach(source IN LISTS ARGN)
 set(test 0)
 foreach(type IN LISTS DTYPES)
 if(type MATCHES "fp16")
 set(type1 "_f16")
 elseif(type MATCHES "fp32")
 set(type1 "_f32")
 elseif(type MATCHES "fp8")
 set(type1 "_f8")
 elseif(type MATCHES "bf16")
 set(type1 "_b16")
 elseif(type MATCHES "fp64")
 set(type1 "_f64")
 elseif(type MATCHES "int8")
 set(type1 "_i8")
 endif()
 #make an exception for reduction kernels
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance") if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance" OR ${source} MATCHES "device_image_to_column")
 #if filename matches any selected type, exit type loop and do not exclude the file from the list
 set(test 0)
 break()
 elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
 source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
 NOT(source MATCHES type OR source MATCHES type1))
 #if filename contains a type which doesn't match any selected type, mark it for removal
 set(test 1)
 endif()
 endforeach()
 if(test EQUAL 1)
 message("removing instance ${source} ")
 list(REMOVE_ITEM ARGN "${source}")
 endif()
 endforeach()
 endif()
 foreach(source IN LISTS ARGN)
 if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
 message("removing dl instance ${source} ")
 list(REMOVE_ITEM ARGN "${source}")
 endif()
 endforeach()
 #only continue if there are some source files left on the list
@@ -49,8 +49,10 @@ function(add_instance_library INSTANCE_NAME)
 set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
 clang_tidy_check(${INSTANCE_NAME})
 set(result 0)
+message("add_instance_library ${INSTANCE_NAME}")
+else()
+message("skip_instance_library ${INSTANCE_NAME}")
 endif()
-#message("add_instance_library returns ${result}")
 set(result ${result} PARENT_SCOPE)
 endfunction(add_instance_library INSTANCE_NAME)
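
Reviewer note: a source file survives the filter above when its name matches a selected type in DTYPES (either the long `fp16`-style name or the short `_f16`-style tag), is exempt (reduction kernels, and, with this commit, `device_image_to_column`), or carries no type tag at all. A condensed, `cmake -P`-runnable sketch of that keep/drop decision; all file names here are hypothetical, and the tag regex is a simplification of the function's per-type checks:

    # Hedged sketch of the same keep/drop logic; names and regex are illustrative only.
    set(DTYPES fp16)
    set(candidates
        device_op_f16_instance.cpp   # tagged with a selected type -> kept
        device_op_f64_instance.cpp   # tagged, but fp64 is not selected -> dropped
        device_reduce_instance.cpp   # reduction exemption -> kept
        device_image_to_column.cpp   # exemption added by this commit -> kept
        device_op_untyped.cpp        # no type tag at all -> kept
    )
    set(kept)
    foreach(source IN LISTS candidates)
        if(source MATCHES "_f16" OR
           source MATCHES "device_reduce_instance" OR
           source MATCHES "device_image_to_column" OR
           NOT source MATCHES "_f(8|16|32|64)|_i8|_b16|fp(8|16|32|64)|bf16|int8")
            list(APPEND kept "${source}")
        endif()
    endforeach()
    message(STATUS "kept: ${kept}")   # four of the five candidates survive
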
@@ -58,65 +60,70 @@ endfunction(add_instance_library INSTANCE_NAME)
 file(GLOB dir_list LIST_DIRECTORIES true *)
 set(CK_DEVICE_INSTANCES)
 FOREACH(subdir_path ${dir_list})
 set(target_dir)
 IF(IS_DIRECTORY "${subdir_path}")
 set(cmake_instance)
 file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
 set(add_inst 0)
 if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
 message("fp8 instance found!")
 set(add_inst 1)
 endif()
 if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
 message("fp16 instance found!")
 set(add_inst 1)
 endif()
 if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
 message("fp32 instance found!")
 set(add_inst 1)
 endif()
 if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
 message("fp64 instance found!")
 set(add_inst 1)
 endif()
 if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
 message("bf16 instance found!")
 set(add_inst 1)
 endif()
 if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
 message("int8 instance found!")
 set(add_inst 1)
 endif()
if(NOT "${cmake_instance}" MATCHES "_fp8" OR if(NOT ("${cmake_instance}" MATCHES "_fp8" OR
NOT "${cmake_instance}" MATCHES "_f8" OR "${cmake_instance}" MATCHES "_f8" OR
NOT "${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_fp16" OR
NOT "${cmake_instance}" MATCHES "_f16" OR "${cmake_instance}" MATCHES "_f16" OR
NOT "${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_fp32" OR
NOT "${cmake_instance}" MATCHES "_f32" OR "${cmake_instance}" MATCHES "_f32" OR
NOT "${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_fp64" OR
NOT "${cmake_instance}" MATCHES "_f64" OR "${cmake_instance}" MATCHES "_f64" OR
NOT "${cmake_instance}" MATCHES "_bf16" OR "${cmake_instance}" MATCHES "_bf16" OR
NOT "${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_int8" OR
NOT "${cmake_instance}" MATCHES "_i8" OR "${cmake_instance}" MATCHES "_i8" OR
NOT "${cmake_instance}" MATCHES "_int4" OR "${cmake_instance}" MATCHES "_int4"))
NOT DEFINED DTYPES) message("instance should be built for all types!")
message("instance should be built for all types!") set(add_inst 1)
set(add_inst 1) endif()
endif() if(NOT DEFINED DTYPES)
if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8") set(add_inst 1)
message("quantization instances will not be built!") endif()
set(add_inst 0) if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
endif() message("quantization instances will not be built!")
if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
set(add_inst 0) set(add_inst 0)
endif() endif()
if(add_inst EQUAL 1) if(("${cmake_instance}" MATCHES "ONLY DL_KERNELS") AND (NOT DEFINED DL_KERNELS))
get_filename_component(target_dir ${subdir_path} NAME) message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
add_subdirectory(${target_dir}) set(add_inst 0)
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>) endif()
endif() if((add_inst EQUAL 1))
ENDIF() get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
message("add_instance_directory ${subdir_path}")
else()
message("skip_instance_directory ${subdir_path}")
endif()
ENDIF()
 ENDFOREACH()
 add_library(device_operations STATIC ${CK_DEVICE_INSTANCES})
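
Reviewer note: the rewritten fallback guard in the loop above fixes a boolean slip. By De Morgan's laws, a chain of `NOT x OR NOT y OR ...` is true whenever at least one pattern is absent, so the old guard treated virtually every subdirectory as "build for all types". The new `NOT (x OR y OR ...)` form fires only when no type tag appears at all, and `NOT DEFINED DTYPES` becomes its own check. A minimal `cmake -P`-runnable sketch of the difference, using a hypothetical contents string:

    # Hedged sketch: old vs. new predicate on a string that clearly contains a type tag.
    set(cmake_instance "add_instance_library(device_foo_fp16_instance ...)")
    # Old form: true because the string lacks _fp32, even though it has _fp16.
    if(NOT cmake_instance MATCHES "_fp16" OR NOT cmake_instance MATCHES "_fp32")
        message(STATUS "old guard fires (incorrectly treats this as untyped)")
    endif()
    # New form: true only when none of the tags is present.
    if(NOT (cmake_instance MATCHES "_fp16" OR cmake_instance MATCHES "_fp32"))
        message(STATUS "new guard fires")
    else()
        message(STATUS "new guard correctly sees a typed instance")
    endif()
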
@@ -158,11 +165,11 @@ target_compile_options(device_operations PRIVATE
 # install(TARGETS device_operations LIBRARY DESTINATION lib)
 rocm_install(TARGETS device_operations
 EXPORT device_operationsTargets)
 rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
 rocm_install(EXPORT device_operationsTargets
 FILE composable_kerneldevice_operationsTargets.cmake
 NAMESPACE composable_kernel::
 DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
 )
@@ -96,13 +96,9 @@ list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instan
 add_instance_library(device_gemm_instance ${GEMM_INSTANCES})
-set(ENABLE_PIPELINE_V2_OPT OFF)
+set(ENABLE_PIPELINE_V2_OPT)
 if (ENABLE_PIPELINE_V2_OPT)
-set(MAX_ILP_OPTS
--mllvm
--amdgpu-enable-max-ilp-scheduling-strategy
-)
 set(WAVES_PER_EU_DEFS
 CK_USE_WAVES_PER_EU=1
 CK_MIN_WAVES_PER_EU=1
@@ -118,7 +114,7 @@ if (ENABLE_PIPELINE_V2_OPT)
 COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}")
 # layout=NN
 set_source_files_properties(device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
-COMPILE_OPTIONS "${MAX_ILP_OPTS}"
+COMPILE_OPTIONS ";;"
 COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}")
 # layout=TT
 set_source_files_properties(device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
@@ -126,7 +122,7 @@ if (ENABLE_PIPELINE_V2_OPT)
 COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}")
 # layout=TN
 set_source_files_properties(device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
-COMPILE_OPTIONS "${MAX_ILP_OPTS}"
+COMPILE_OPTIONS ";;"
 COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}")
 endif(ENABLE_PIPELINE_V2_OPT)
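
Reviewer note on the two COMPILE_OPTIONS edits: with the MAX_ILP_OPTS list deleted, the property is set to ";;", which CMake parses as a list of empty elements, so no extra per-file compiler flags remain and only the compile definitions still apply. A small sketch of the two forms; demo.cpp is a placeholder and this assumes a surrounding project():

    # Hedged sketch; demo.cpp and the flag list are illustrative placeholders.
    set(MAX_ILP_OPTS -mllvm -amdgpu-enable-max-ilp-scheduling-strategy)
    # Before: the per-file options expand from the list variable.
    set_source_files_properties(demo.cpp PROPERTIES COMPILE_OPTIONS "${MAX_ILP_OPTS}")
    # After: ";;" holds only empty list items, so no options are added.
    set_source_files_properties(demo.cpp PROPERTIES COMPILE_OPTIONS ";;")
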
-add_instance_library(device_grouped_conv3d_bwd_data_instance
+set(GROUPED_CONV3D_BWD_DATA
 xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
 xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
 xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
 xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
 xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
 xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
-xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp
 wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
 wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
 wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
@@ -13,5 +12,11 @@ add_instance_library(device_grouped_conv3d_bwd_data_instance
 wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp
 wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp
 wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
-wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp
-)
+wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp)
+if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
+list(APPEND GROUPED_CONV3D_BWD_DATA
+xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp)
+endif()
+add_instance_library(device_grouped_conv3d_bwd_data_instance ${GROUPED_CONV3D_BWD_DATA})
-add_instance_library(device_grouped_conv3d_fwd_instance
+set(GROUPED_CONV3D_FWD
 xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
 xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
 xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
 xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp
 xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
 xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
 xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
 xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
-xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp
-wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp
-)
+wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp)
+if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
+list(APPEND GROUPED_CONV3D_FWD
+xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp)
+endif()
+add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD})
 set(GROUPED_GEMM_FIXED_NK_INSTANCES)
-if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp)
-list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp)
-endif()
-if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
-list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp)
-list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp)
-endif()
-if((DTYPES MATCHES "int8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
-list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp)
-list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp)
-endif()
+list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp
+device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp
+device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp
+device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp
+device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp
+device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp)
 add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES})
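
Reviewer note: this hunk drops the per-type guards, so the fp8 and int8 mixed-input grouped-GEMM instances are now registered whenever the target itself is built; the one-call `list(APPEND)` with several items is equivalent to the repeated one-item calls it replaces, as this sketch with hypothetical names shows:

    # Hedged sketch: multi-item append vs. repeated single-item appends.
    set(INSTANCES)
    list(APPEND INSTANCES a_mk_kn_mn.cpp a_mk_nk_mn.cpp)  # consolidated style
    list(APPEND INSTANCES b_mk_kn_mn.cpp)                 # one item per call
    list(APPEND INSTANCES b_mk_nk_mn.cpp)
    message(STATUS "${INSTANCES}")  # a_mk_kn_mn.cpp;a_mk_nk_mn.cpp;b_mk_kn_mn.cpp;b_mk_nk_mn.cpp
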
@@ -11,7 +11,7 @@ namespace instance {
 using Pass = ck::tensor_operation::element_wise::PassThrough;
 void add_device_normalization_rank_5_3_f16_instances(
-std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 5, 3>>>&
+std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Pass, 5, 3>>>&
 instances)
 {
 add_device_operation_instances(instances,
...
@@ -11,7 +11,7 @@ namespace instance {
 using Swish = ck::tensor_operation::element_wise::Swish;
 void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
-std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&
+std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F16, F32, Swish, 5, 3>>>&
 instances)
 {
 add_device_operation_instances(
...
@@ -11,7 +11,7 @@ namespace instance {
 using Swish = ck::tensor_operation::element_wise::Swish;
 void add_device_normalization_rank_5_3_swish_f16_instances(
-std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>&
+std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Swish, 5, 3>>>&
 instances)
 {
 add_device_operation_instances(instances,
...
@@ -11,7 +11,7 @@ namespace instance {
 using Pass = ck::tensor_operation::element_wise::PassThrough;
 void add_device_normalization_rank_2_1_f16_instances(
-std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 2, 1>>>&
+std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Pass, 2, 1>>>&
 instances)
 {
 add_device_operation_instances(instances,
...
@@ -11,7 +11,7 @@ namespace instance {
 using Pass = ck::tensor_operation::element_wise::PassThrough;
 void add_device_normalization_rank_4_3_f16_instances(
-std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 4, 3>>>&
+std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Pass, 4, 3>>>&
 instances)
 {
 add_device_operation_instances(instances,
...
@@ -22,25 +22,25 @@ template <typename OutElementwise, index_t Rank, index_t Reduce>
 using device_normalization_f16_instances =
 // clang-format off
 std::tuple <
-// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>, // irregular size
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8>
+// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
 // clang-format on
 >;
@@ -48,150 +48,150 @@ template <typename OutElementwise, index_t Rank, index_t Reduce>
 using device_normalization_splitk_f16_instances =
 // clang-format off
 std::tuple <
-// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>, // irregular size
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
-DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8>
+// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
+DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
 // clang-format on
 >;
 template <typename OutElementwise, index_t Rank, index_t Reduce>
 using device_normalization_f16_generic_instance = std::tuple<
 // clang-format off
-DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1>
+DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
 // clang-format on
 >;
 template <typename OutElementwise, index_t Rank, index_t Reduce>
 using device_normalization_f32_instances = std::tuple<
 // clang-format off
-// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
+// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
 // clang-format on
 >;
 template <typename OutElementwise, index_t Rank, index_t Reduce>
 using device_normalization_splitk_f32_instances = std::tuple<
 // clang-format off
-// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
+// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
 // clang-format on
 >;
 template <typename OutElementwise, index_t Rank, index_t Reduce>
 using device_normalization_f32_generic_instance = std::tuple<
 // clang-format off
-DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1>
+DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
 // clang-format on
 >;
 template <typename OutElementwise, index_t Rank, index_t Reduce>
 using device_normalization_f16_f32_f32_f16_instances = std::tuple<
 // clang-format off
-// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
+// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
 // clang-format on
 >;
 template <typename OutElementwise, index_t Rank, index_t Reduce>
 using device_normalization_splitk_f16_f32_f32_f16_instances = std::tuple<
 // clang-format off
-// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
-DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
+// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
+DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
// clang-format on // clang-format on
>; >;
template <typename OutElementwise, index_t Rank, index_t Reduce> template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple< using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple<
// clang-format off // clang-format off
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1> DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
// clang-format on // clang-format on
>; >;
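The tuples above are compile-time instance registries: each entry is a fully specialized device op, and the profiler walks the tuple to try every configuration. A generic sketch of the pattern with simplified stand-in types (not the library's actual classes):

```cpp
// Sketch of the "tuple of instance types" registry pattern; Instance is an
// illustrative stand-in for a fully specialized device op.
#include <cstdio>
#include <tuple>

template <int BlockSize, int KThreadSliceSize>
struct Instance
{
    static void describe() { std::printf("BlockSize=%d KSlice=%d\n", BlockSize, KThreadSliceSize); }
};

using Registry = std::tuple<Instance<512, 4>, Instance<512, 8>, Instance<1024, 4>>;

int main()
{
    // Visit every registered configuration, as the profiler loop does.
    std::apply([](auto... inst) { (decltype(inst)::describe(), ...); }, Registry{});
    return 0;
}
```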
......
...@@ -22,7 +22,7 @@ c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} ...@@ -22,7 +22,7 @@ c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s
``` ```
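As a sanity check on the figures in the hunk above: with M = 3840 and N = 4096 taken from the printed c_m_n lengths, the 107.977 TFlops number is consistent with K = 4096 (an assumption read off the reported perf, not stated in the hunk):

```cpp
// Verify the reported TFlops from the printed problem size and time.
#include <cstdio>
int main()
{
    const double M = 3840, N = 4096, K = 4096; // K = 4096 assumed
    const double time_ms = 1.1933;
    const double flop    = 2.0 * M * N * K; // one multiply-add = 2 flops
    std::printf("%.3f TFlops\n", flop / (time_ms * 1e-3) / 1e12); // ~107.977
    return 0;
}
```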
## Profile 2d forward convolution kernels ## Profile 2D forward convolution kernels
```bash ```bash
#arg1: tensor operation (conv=Convolution) #arg1: tensor operation (conv=Convolution)
#arg2: data type (0=fp32, 1=fp16) #arg2: data type (0=fp32, 1=fp16)
...@@ -115,7 +115,7 @@ Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s ...@@ -115,7 +115,7 @@ Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s
# arg6: print tensor value (0: no; 1: yes) # arg6: print tensor value (0: no; 1: yes)
# arg7: time kernel (0: no, 1: yes) # arg7: time kernel (0: no, 1: yes)
# Following arguments (depending on number of spatial dims): # Following arguments (depending on number of spatial dims):
# Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) # Number of spatial dimensions (1=Conv1D, 2=Conv2D, 3=Conv3D)
# G, N, K, C, # G, N, K, C,
# <filter spatial dimensions>, (i.e. Y, X for 2D) # <filter spatial dimensions>, (i.e. Y, X for 2D)
# <input image spatial dimensions>, (i.e. Hi, Wi for 2D) # <input image spatial dimensions>, (i.e. Hi, Wi for 2D)
...@@ -158,7 +158,7 @@ GB/s: 127.947 ...@@ -158,7 +158,7 @@ GB/s: 127.947
# arg6: print tensor value (0: no; 1: yes) # arg6: print tensor value (0: no; 1: yes)
# arg7: time kernel (0: no, 1: yes) # arg7: time kernel (0: no, 1: yes)
# Following arguments (depending on number of spatial dims): # Following arguments (depending on number of spatial dims):
# Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) # Number of spatial dimensions (1=Conv1D, 2=Conv2D, 3=Conv3D)
# G, N, K, C, # G, N, K, C,
# <filter spatial dimensions>, (i.e. Y, X for 2D) # <filter spatial dimensions>, (i.e. Y, X for 2D)
# <input image spatial dimensions>, (i.e. Hi, Wi for 2D) # <input image spatial dimensions>, (i.e. Hi, Wi for 2D)
...@@ -201,7 +201,7 @@ Note: This kernel uses atomic add, which will cause the output buffer to accumulate ...@@ -201,7 +201,7 @@ Note: This kernel uses atomic add, which will cause the output buffer to accumulate
# arg7: time kernel (0: no, 1: yes) # arg7: time kernel (0: no, 1: yes)
# arg8: operation type (0: ImageToColumn, 1: ColumnToImage) # arg8: operation type (0: ImageToColumn, 1: ColumnToImage)
# Following arguments (depending on number of spatial dims): # Following arguments (depending on number of spatial dims):
# Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) # Number of spatial dimensions (1=Conv1D, 2=Conv2D, 3=Conv3D)
# G, N, K, C, # G, N, K, C,
# <filter spatial dimensions>, (i.e. Y, X for 2D) # <filter spatial dimensions>, (i.e. Y, X for 2D)
# <input image spatial dimensions>, (i.e. Hi, Wi for 2D) # <input image spatial dimensions>, (i.e. Hi, Wi for 2D)
......
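The conv_tensor_rearrange profiler above toggles between ImageToColumn and ColumnToImage via arg8. For orientation, a minimal 1D image-to-column sketch follows; the single-group, unit-stride, unit-dilation, zero-padding setup is a simplifying assumption, not the library kernel:

```cpp
// Minimal 1D image-to-column: gathers every filter window of length X into
// one row of the output matrix (illustrative only).
#include <cstdio>
#include <vector>

std::vector<float> image_to_column(const std::vector<float>& in, int C, int X)
{
    const int Wi = static_cast<int>(in.size()) / C; // input width (N = 1, NWC-like)
    const int Wo = Wi - X + 1;                      // output width, no padding
    std::vector<float> out(static_cast<size_t>(Wo) * X * C);
    for(int wo = 0; wo < Wo; ++wo)     // one output row per window
        for(int x = 0; x < X; ++x)     // filter tap
            for(int c = 0; c < C; ++c) // channel
                out[(wo * X + x) * C + c] = in[(wo + x) * C + c];
    return out;
}

int main()
{
    auto col = image_to_column({0, 1, 2, 3, 4, 5}, /*C=*/1, /*X=*/3);
    for(float v : col) std::printf("%g ", v); // 0 1 2 1 2 3 2 3 4 3 4 5
    return 0;
}
```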
...@@ -80,6 +80,8 @@ bool profile_elementwise_layernorm_impl(int do_verification, ...@@ -80,6 +80,8 @@ bool profile_elementwise_layernorm_impl(int do_verification,
Tensor<BetaDataType> beta(gammaBetaLength); Tensor<BetaDataType> beta(gammaBetaLength);
Tensor<YDataType> y(length); Tensor<YDataType> y(length);
Tensor<YDataType> host_y(length); Tensor<YDataType> host_y(length);
Tensor<AccDataType> host_save_mean({M});
Tensor<AccDataType> host_save_inv_std({M});
switch(init_method) switch(init_method)
{ {
...@@ -152,14 +154,23 @@ bool profile_elementwise_layernorm_impl(int do_verification, ...@@ -152,14 +154,23 @@ bool profile_elementwise_layernorm_impl(int do_verification,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
ReferenceInstance ref; ReferenceInstance ref;
auto ref_argument = auto ref_argument = ref.MakeArgument(x,
ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4); gamma,
auto ref_invoker = ref.MakeInvoker(); beta,
host_y,
host_save_mean,
host_save_inv_std,
PassThrough{},
{M, N},
{1},
1e-4);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
......
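The reference layernorm now also fills host_save_mean and host_save_inv_std, one value per reduced row. A minimal sketch of the assumed semantics (the mean, and the reciprocal standard deviation with epsilon folded in); the names are illustrative, not the library's API:

```cpp
// Per-row statistics a layernorm reference saves alongside the output.
#include <cmath>
#include <cstdio>
#include <vector>

void row_stats(const std::vector<float>& x, float epsilon, float& mean_out, float& inv_std_out)
{
    double sum = 0.0, sq_sum = 0.0;
    for(float v : x) { sum += v; sq_sum += double(v) * v; }
    const double mean = sum / x.size();
    const double var  = sq_sum / x.size() - mean * mean; // biased variance
    mean_out    = static_cast<float>(mean);
    inv_std_out = static_cast<float>(1.0 / std::sqrt(var + epsilon));
}

int main()
{
    float m, is;
    row_stats({1.f, 2.f, 3.f, 4.f}, 1e-4f, m, is);
    std::printf("mean=%g inv_std=%g\n", m, is); // mean=2.5, inv_std~0.8944
    return 0;
}
```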
...@@ -66,12 +66,15 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n, ...@@ -66,12 +66,15 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
BetaDataType, BetaDataType,
HDataType, HDataType,
AccDataType, AccDataType,
AccDataType,
HElementOp, HElementOp,
2, 2,
1>; 1>;
Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N}); Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N}); Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> save_mean({M});
Tensor<AccDataType> save_inv_std({M});
auto ref_gemm = ReferenceGemm{}; auto ref_gemm = ReferenceGemm{};
auto ref_gemm_invoker = ref_gemm.MakeInvoker(); auto ref_gemm_invoker = ref_gemm.MakeInvoker();
...@@ -97,7 +100,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n, ...@@ -97,7 +100,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
auto ref_layernorm_invoker = ref_layernorm.MakeInvoker(); auto ref_layernorm_invoker = ref_layernorm.MakeInvoker();
auto ref_layernorm_argument = ref_layernorm.MakeArgument( auto ref_layernorm_argument = ref_layernorm.MakeArgument(
e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon); e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon);
ref_layernorm_invoker.Run(ref_layernorm_argument); ref_layernorm_invoker.Run(ref_layernorm_argument);
} }
......
...@@ -21,8 +21,10 @@ namespace profiler { ...@@ -21,8 +21,10 @@ namespace profiler {
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename AccDataType, typename ComputeDataType,
typename YDataType> typename YDataType,
typename SaveMeanInvStdDataType,
bool SaveMeanInvStd>
bool profile_groupnorm_impl(int do_verification, bool profile_groupnorm_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
...@@ -34,6 +36,7 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -34,6 +36,7 @@ bool profile_groupnorm_impl(int do_verification,
if(length.size() != 5) if(length.size() != 5)
return false; return false;
index_t N = length[0];
index_t G = length[3]; index_t G = length[3];
index_t C = length[4]; index_t C = length[4];
...@@ -45,7 +48,14 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -45,7 +48,14 @@ bool profile_groupnorm_impl(int do_verification,
Tensor<GammaDataType> gamma(gammaBetaLength); Tensor<GammaDataType> gamma(gammaBetaLength);
Tensor<BetaDataType> beta(gammaBetaLength); Tensor<BetaDataType> beta(gammaBetaLength);
Tensor<YDataType> y(length); Tensor<YDataType> y(length);
Tensor<SaveMeanInvStdDataType> save_mean({N, G});
Tensor<SaveMeanInvStdDataType> save_inv_std({N, G});
Tensor<YDataType> host_y(length); Tensor<YDataType> host_y(length);
Tensor<SaveMeanInvStdDataType> host_save_mean({N, G});
Tensor<SaveMeanInvStdDataType> host_save_inv_std({N, G});
std::vector<index_t> strideSaveMeanInvStd = {1};
switch(init_method) switch(init_method)
{ {
...@@ -69,6 +79,9 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -69,6 +79,9 @@ bool profile_groupnorm_impl(int do_verification,
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
save_inv_std.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data()); gamma_dev.ToDevice(gamma.mData.data());
...@@ -78,8 +91,8 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -78,8 +91,8 @@ bool profile_groupnorm_impl(int do_verification,
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType, using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
PassThrough, PassThrough,
5, 5,
3>; 3>;
...@@ -97,38 +110,70 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -97,38 +110,70 @@ bool profile_groupnorm_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm<XDataType, using ReferenceInstance =
GammaDataType, ck::tensor_operation::host::ReferenceGroupnorm<XDataType,
BetaDataType, GammaDataType,
YDataType, BetaDataType,
AccDataType, YDataType,
PassThrough>; SaveMeanInvStdDataType,
ComputeDataType,
PassThrough>;
ReferenceInstance ref; ReferenceInstance ref;
auto ref_argument = ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, length, 1e-6); auto ref_argument = ref.MakeArgument(
auto ref_invoker = ref.MakeInvoker(); x, gamma, beta, host_y, host_save_mean, host_save_inv_std, PassThrough{}, length, 1e-6);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
int num_kernel = 0; int num_kernel = 0;
auto f_get_argument = [&](auto& inst_ptr) {
if constexpr(SaveMeanInvStd)
return inst_ptr->MakeArgumentPointer(
length,
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
gammaBetaStride,
gammaBetaStride,
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_inv_std.mDesc.GetStrides().begin(),
save_inv_std.mDesc.GetStrides().end()},
reduce_dim,
1e-6,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
save_mean_dev.GetDeviceBuffer(),
save_inv_std_dev.GetDeviceBuffer(),
PassThrough{});
else
return inst_ptr->MakeArgumentPointer(
length,
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
gammaBetaStride,
gammaBetaStride,
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_inv_std.mDesc.GetStrides().begin(),
save_inv_std.mDesc.GetStrides().end()},
reduce_dim,
1e-6,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
};
for(auto& inst_ptr : instance_ptrs) for(auto& inst_ptr : instance_ptrs)
{ {
auto argument_ptr = inst_ptr->MakeArgumentPointer( auto argument_ptr = f_get_argument(inst_ptr);
length,
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
gammaBetaStride,
gammaBetaStride,
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
reduce_dim,
1e-6,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
if(inst_ptr->IsSupportedArgument(argument_ptr.get())) if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
...@@ -152,6 +197,10 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -152,6 +197,10 @@ bool profile_groupnorm_impl(int do_verification,
beta.mDesc.GetElementSize() * sizeof(BetaDataType) + beta.mDesc.GetElementSize() * sizeof(BetaDataType) +
y.mDesc.GetElementSize() * sizeof(YDataType); y.mDesc.GetElementSize() * sizeof(YDataType);
if constexpr(SaveMeanInvStd)
num_bytes += save_mean.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType) +
save_inv_std.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType);
float gb_per_sec = num_bytes / 1.E6 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel) if(time_kernel)
...@@ -168,9 +217,22 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -168,9 +217,22 @@ bool profile_groupnorm_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
bool pass = ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); bool pass = ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
if constexpr(SaveMeanInvStd)
{
save_mean_dev.FromDevice(save_mean.mData.data());
pass &= ck::utils::check_err(
save_mean.mData, host_save_mean.mData, "Error: Incorrect results", 1e-3, 1e-3);
save_inv_std_dev.FromDevice(save_inv_std.mData.data());
pass &= ck::utils::check_err(save_inv_std.mData,
host_save_inv_std.mData,
"Error: Incorrect results",
1e-3,
1e-3);
}
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "x : ", x.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "x : ", x.mData, ",") << std::endl;
......
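The f_get_argument lambda above selects one of the two MakeArgumentPointer calls at compile time, so the untaken branch is discarded rather than merely skipped. A standalone illustration of the C++17 pattern, using stand-in names rather than the library's types:

```cpp
// Compile-time dispatch on a bool template parameter via if constexpr.
#include <cstdio>

template <bool SaveStats>
void make_argument()
{
    if constexpr(SaveStats)
        std::printf("pass device buffers for mean/inv_std\n");
    else
        std::printf("pass nullptr for mean/inv_std\n"); // other branch not instantiated
}

int main()
{
    make_argument<true>();
    make_argument<false>();
    return 0;
}
```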
...@@ -21,6 +21,8 @@ template <typename XDataType, ...@@ -21,6 +21,8 @@ template <typename XDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
bool SaveMeanInvStd,
index_t Rank> index_t Rank>
bool profile_layernorm_impl(int do_verification, bool profile_layernorm_impl(int do_verification,
int init_method, int init_method,
...@@ -43,13 +45,19 @@ bool profile_layernorm_impl(int do_verification, ...@@ -43,13 +45,19 @@ bool profile_layernorm_impl(int do_verification,
Tensor<GammaDataType> gamma(reduce_length); Tensor<GammaDataType> gamma(reduce_length);
Tensor<BetaDataType> beta(reduce_length); Tensor<BetaDataType> beta(reduce_length);
Tensor<YDataType> y(length); Tensor<YDataType> y(length);
Tensor<SaveMeanInvStdDataType> save_mean({length[0]});
Tensor<SaveMeanInvStdDataType> save_inv_std({length[0]});
Tensor<YDataType> host_y(length); Tensor<YDataType> host_y(length);
Tensor<SaveMeanInvStdDataType> host_save_mean({length[0]});
Tensor<SaveMeanInvStdDataType> host_save_inv_std({length[0]});
std::vector<index_t> strideXY = std::vector<index_t> strideXY =
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
std::vector<index_t> strideGammaBeta = strideXY; std::vector<index_t> strideGammaBeta = strideXY;
strideGammaBeta[0] = 0; strideGammaBeta[0] = 0;
std::vector<index_t> strideSaveMeanInvStd = {1};
switch(init_method) switch(init_method)
{ {
case 0: case 0:
...@@ -75,6 +83,9 @@ bool profile_layernorm_impl(int do_verification, ...@@ -75,6 +83,9 @@ bool profile_layernorm_impl(int do_verification,
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
save_inv_std.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data()); gamma_dev.ToDevice(gamma.mData.data());
...@@ -86,8 +97,8 @@ bool profile_layernorm_impl(int do_verification, ...@@ -86,8 +97,8 @@ bool profile_layernorm_impl(int do_verification,
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType, using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
...@@ -105,40 +116,74 @@ bool profile_layernorm_impl(int do_verification, ...@@ -105,40 +116,74 @@ bool profile_layernorm_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm<XDataType, using ReferenceInstance =
GammaDataType, ck::tensor_operation::host::ReferenceLayernorm<XDataType,
BetaDataType, GammaDataType,
YDataType, BetaDataType,
ComputeDataType, YDataType,
PassThrough, SaveMeanInvStdDataType,
Rank, ComputeDataType,
NumReduceDim>; PassThrough,
Rank,
NumReduceDim>;
ReferenceInstance ref; ReferenceInstance ref;
auto ref_argument = auto ref_argument = ref.MakeArgument(x,
ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, length, reduce_dim, 1e-4); gamma,
auto ref_invoker = ref.MakeInvoker(); beta,
host_y,
host_save_mean,
host_save_inv_std,
PassThrough{},
length,
reduce_dim,
1e-4);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
int num_kernel = 0; int num_kernel = 0;
auto f_get_argument = [&](auto& inst_ptr) {
if constexpr(SaveMeanInvStd)
return inst_ptr->MakeArgumentPointer(length,
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
strideSaveMeanInvStd,
strideSaveMeanInvStd,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
save_mean_dev.GetDeviceBuffer(),
save_inv_std_dev.GetDeviceBuffer(),
PassThrough{});
else
return inst_ptr->MakeArgumentPointer(length,
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
strideSaveMeanInvStd,
strideSaveMeanInvStd,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
};
for(auto& inst_ptr : instance_ptrs) for(auto& inst_ptr : instance_ptrs)
{ {
auto argument_ptr = inst_ptr->MakeArgumentPointer(length, auto argument_ptr = f_get_argument(inst_ptr);
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
if(inst_ptr->IsSupportedArgument(argument_ptr.get())) if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
...@@ -168,6 +213,10 @@ bool profile_layernorm_impl(int do_verification, ...@@ -168,6 +213,10 @@ bool profile_layernorm_impl(int do_verification,
beta.mDesc.GetElementSize() * sizeof(BetaDataType) + beta.mDesc.GetElementSize() * sizeof(BetaDataType) +
y.mDesc.GetElementSize() * sizeof(YDataType); y.mDesc.GetElementSize() * sizeof(YDataType);
if constexpr(SaveMeanInvStd)
num_bytes += save_mean.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType) +
save_inv_std.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType);
float gb_per_sec = num_bytes / 1.E6 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel) if(time_kernel)
...@@ -184,10 +233,23 @@ bool profile_layernorm_impl(int do_verification, ...@@ -184,10 +233,23 @@ bool profile_layernorm_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
bool pass = bool pass =
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
if constexpr(SaveMeanInvStd)
{
save_mean_dev.FromDevice(save_mean.mData.data());
pass &= ck::utils::check_err(
save_mean.mData, host_save_mean.mData, "Error: Incorrect results", 1e-3, 1e-3);
save_inv_std_dev.FromDevice(save_inv_std.mData.data());
pass &= ck::utils::check_err(save_inv_std.mData,
host_save_inv_std.mData,
"Error: Incorrect results",
1e-3,
1e-3);
}
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "x : ", x.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "x : ", x.mData, ",") << std::endl;
......
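The bandwidth formula in both profilers divides a byte count by 1e6 and then by a time in milliseconds, which lands directly in GB/s. A quick units check with hypothetical values:

```cpp
// bytes / 1e6 = MB; MB / ms = GB/s, matching the gb_per_sec expression above.
#include <cstdio>
int main()
{
    const double num_bytes = 64.0 * 1024 * 1024; // hypothetical: 64 MiB moved
    const double avg_time  = 0.5;                // hypothetical: 0.5 ms per launch
    std::printf("%.3f GB/s\n", num_bytes / 1.e6 / avg_time); // ~134.218
    return 0;
}
```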
...@@ -25,8 +25,6 @@ set(PROFILER_SOURCES ...@@ -25,8 +25,6 @@ set(PROFILER_SOURCES
profile_batchnorm_fwd.cpp profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp profile_batchnorm_infer.cpp
profile_contraction_bilinear.cpp
profile_contraction_scale.cpp
profile_grouped_conv_bwd_data.cpp profile_grouped_conv_bwd_data.cpp
profile_conv_tensor_rearrange.cpp profile_conv_tensor_rearrange.cpp
) )
...@@ -46,6 +44,11 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) ...@@ -46,6 +44,11 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
endif() endif()
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
endif()
set(PROFILER_EXECUTABLE ckProfiler) set(PROFILER_EXECUTABLE ckProfiler)
add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES}) add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
...@@ -76,8 +79,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan ...@@ -76,8 +79,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
...@@ -85,9 +86,18 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_d ...@@ -85,9 +86,18 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_d
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
endif()
if(DL_KERNELS) if(DL_KERNELS)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
endif() endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
......
...@@ -86,12 +86,8 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -86,12 +86,8 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
#ifdef CK_ENABLE_FP8 using F8 = ck::f8_t;
using F8 = ck::f8_t; using BF8 = ck::bf8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using namespace ck::tensor_layout::convolution; using namespace ck::tensor_layout::convolution;
...@@ -141,59 +137,59 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -141,59 +137,59 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{ {
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
...@@ -204,22 +200,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -204,22 +200,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
} }
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
else if(data_type == ConvDataType::F16_F16_F16_BF8_F8) if(data_type == ConvDataType::F16_F16_F16_BF8_F8)
{ {
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
} }
......
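The control-flow cleanup in the file above works because every branch ends in a return, so each else-if can be flattened to a plain if without changing behavior. A minimal illustration:

```cpp
// Early returns make the "else" redundant; both forms are equivalent here.
#include <cstdio>

int classify(int num_dim_spatial)
{
    if(num_dim_spatial == 1) { return 1; }
    if(num_dim_spatial == 2) { return 2; } // reached only if the first test failed
    if(num_dim_spatial == 3) { return 3; }
    return -1;
}

int main() { std::printf("%d\n", classify(2)); return 0; }
```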
...@@ -93,12 +93,12 @@ int profile_groupnorm(int argc, char* argv[]) ...@@ -93,12 +93,12 @@ int profile_groupnorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Float) if(data_type == ck::DataTypeEnum::Float)
{ {
ck::profiler::profile_groupnorm_impl<F32, F32, F32, F32, F32>( ck::profiler::profile_groupnorm_impl<F32, F32, F32, F32, F32, F32, false>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else if(data_type == ck::DataTypeEnum::Half) else if(data_type == ck::DataTypeEnum::Half)
{ {
ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16>( ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16, F32, false>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else else
......
...@@ -82,12 +82,12 @@ int profile_layernorm(int argc, char* argv[]) ...@@ -82,12 +82,12 @@ int profile_layernorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Half) if(data_type == ck::DataTypeEnum::Half)
{ {
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, rank>( ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F32, false, rank>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else if(data_type == ck::DataTypeEnum::Float) else if(data_type == ck::DataTypeEnum::Float)
{ {
ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, rank>( ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, F32, false, rank>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else else
......