Commit 9b7093aa authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Merge branch 'develop' into lwpck-791

parents 4cb63955 bba085d2
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_bilinear_instance add_instance_library(device_gemm_bilinear_instance
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
...@@ -9,4 +8,3 @@ add_instance_library(device_gemm_bilinear_instance ...@@ -9,4 +8,3 @@ add_instance_library(device_gemm_bilinear_instance
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp
) )
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_fastgelu_instance add_instance_library(device_gemm_fastgelu_instance
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
) )
endif()
set(GEMM_MULTIPLY_ADD_INSTANCES) set(GEMM_MULTIPLY_ADD_INSTANCES)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp) device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp) device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
endif()
if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
endif()
add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES}) add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES})
set(GEMM_SPLITK_INSTANCES) set(GEMM_SPLITK_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp) device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp) device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp) device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp) device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
endif() device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp) device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp) device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp) device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp) device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp
endif() device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp
if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp) device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp)
endif()
add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES}) add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES})
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_streamk_instance add_instance_library(device_gemm_streamk_instance
# device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp # device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp
# device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp # device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp
...@@ -9,4 +8,3 @@ add_instance_library(device_gemm_streamk_instance ...@@ -9,4 +8,3 @@ add_instance_library(device_gemm_streamk_instance
# device_gemm_xdl_streamk_f16_f16_f16_km_kn_mn_instance.cpp # device_gemm_xdl_streamk_f16_f16_f16_km_kn_mn_instance.cpp
# device_gemm_xdl_streamk_f16_f16_f16_km_nk_mn_instance.cpp # device_gemm_xdl_streamk_f16_f16_f16_km_nk_mn_instance.cpp
) )
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_grouped_gemm_instance add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
...@@ -9,4 +8,3 @@ add_instance_library(device_grouped_gemm_instance ...@@ -9,4 +8,3 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
) )
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_grouped_gemm_fastgelu_instance add_instance_library(device_grouped_gemm_fastgelu_instance
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp
) )
endif()
set(DEVICE_MAXPOOL_BWD_INSTANCES) set(DEVICE_MAXPOOL_BWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp) device_max_pool_bwd_bf16_instance.cpp
endif() device_max_pool_bwd_f32_instance.cpp)
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f32_instance.cpp)
endif()
add_instance_library(device_max_pool_bwd_instance ${DEVICE_MAXPOOL_BWD_INSTANCES}) add_instance_library(device_max_pool_bwd_instance ${DEVICE_MAXPOOL_BWD_INSTANCES})
set(DEVICE_NORMALIZATION_INSTANCES) set(DEVICE_NORMALIZATION_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f16_instance.cpp list(APPEND DEVICE_NORMALIZATION_INSTANCES
device_layernorm2d_f16_instance.cpp
device_layernorm4d_f16_instance.cpp device_layernorm4d_f16_instance.cpp
device_groupnorm_f16_instance.cpp device_groupnorm_f16_instance.cpp
device_groupnorm_swish_f16_instance.cpp device_groupnorm_swish_f16_instance.cpp
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp) device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
endif() device_layernorm2d_f32_instance.cpp
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f32_instance.cpp
device_layernorm4d_f32_instance.cpp device_layernorm4d_f32_instance.cpp
device_groupnorm_f32_instance.cpp device_groupnorm_f32_instance.cpp
device_groupnorm_swish_f32_instance.cpp) device_groupnorm_swish_f32_instance.cpp)
endif()
add_instance_library(device_normalization_instance ${DEVICE_NORMALIZATION_INSTANCES}) add_instance_library(device_normalization_instance ${DEVICE_NORMALIZATION_INSTANCES})
set(DEVICE_POOL3D_FWD_INSTANCES) set(DEVICE_POOL3D_FWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp device_max_pool3d_fwd_ndhwc_f16_instance.cpp
device_max_pool3d_fwd_ndhwc_f16_instance.cpp) device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
endif() device_max_pool3d_fwd_ndhwc_f32_instance.cpp
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp device_max_pool3d_fwd_ndhwc_bf16_instance.cpp)
device_max_pool3d_fwd_ndhwc_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
device_max_pool3d_fwd_ndhwc_f32_instance.cpp)
endif()
add_instance_library(device_pool3d_fwd_instance ${DEVICE_POOL3D_FWD_INSTANCES}) add_instance_library(device_pool3d_fwd_instance ${DEVICE_POOL3D_FWD_INSTANCES})
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp) set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp)
set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp) set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp)
set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp) set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp)
...@@ -10,17 +8,16 @@ set(GEMM_QUANT_SRC ...@@ -10,17 +8,16 @@ set(GEMM_QUANT_SRC
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
) )
if(DL_KERNELS)
list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp) list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp)
list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp) list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp)
list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp) list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp)
list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp) list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp)
list(APPEND GEMM_QUANT_SRC list(APPEND GEMM_QUANT_SRC
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp) gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
endif()
add_instance_library(device_quantization_instance add_instance_library(device_quantization_instance
${CONV2D_PERLAYER_QUANT_SRC} ${CONV2D_PERLAYER_QUANT_SRC}
...@@ -29,4 +26,3 @@ add_instance_library(device_quantization_instance ...@@ -29,4 +26,3 @@ add_instance_library(device_quantization_instance
${CONV2D_BIAS_PERCHANNEL_QUANT_SRC} ${CONV2D_BIAS_PERCHANNEL_QUANT_SRC}
${GEMM_QUANT_SRC} ${GEMM_QUANT_SRC}
) )
endif()
\ No newline at end of file
set(DEVICE_SOFTMAX_INSTANCES) set(DEVICE_SOFTMAX_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND DEVICE_SOFTMAX_INSTANCES
list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f16_f16_instance_rank3_reduce1.cpp device_softmax_f16_f16_instance_rank3_reduce1.cpp
device_softmax_f16_f16_instance_rank3_reduce2.cpp device_softmax_f16_f16_instance_rank3_reduce2.cpp
device_softmax_f16_f16_instance_rank3_reduce3.cpp device_softmax_f16_f16_instance_rank3_reduce3.cpp
device_softmax_f16_f16_instance_rank4_reduce1.cpp device_softmax_f16_f16_instance_rank4_reduce1.cpp
device_softmax_f16_f16_instance_rank4_reduce2.cpp device_softmax_f16_f16_instance_rank4_reduce2.cpp
device_softmax_f16_f16_instance_rank4_reduce3.cpp device_softmax_f16_f16_instance_rank4_reduce3.cpp
device_softmax_f16_f16_instance_rank4_reduce4.cpp) device_softmax_f16_f16_instance_rank4_reduce4.cpp
endif() device_softmax_f32_f32_instance_rank3_reduce1.cpp
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f32_f32_instance_rank3_reduce1.cpp
device_softmax_f32_f32_instance_rank3_reduce2.cpp device_softmax_f32_f32_instance_rank3_reduce2.cpp
device_softmax_f32_f32_instance_rank3_reduce3.cpp device_softmax_f32_f32_instance_rank3_reduce3.cpp
device_softmax_f32_f32_instance_rank4_reduce1.cpp device_softmax_f32_f32_instance_rank4_reduce1.cpp
device_softmax_f32_f32_instance_rank4_reduce2.cpp device_softmax_f32_f32_instance_rank4_reduce2.cpp
device_softmax_f32_f32_instance_rank4_reduce3.cpp device_softmax_f32_f32_instance_rank4_reduce3.cpp
device_softmax_f32_f32_instance_rank4_reduce4.cpp) device_softmax_f32_f32_instance_rank4_reduce4.cpp)
endif()
add_instance_library(device_softmax_instance ${DEVICE_SOFTMAX_INSTANCES}) add_instance_library(device_softmax_instance ${DEVICE_SOFTMAX_INSTANCES})
...@@ -9,26 +9,121 @@ add_custom_target(tests) ...@@ -9,26 +9,121 @@ add_custom_target(tests)
function(add_test_executable TEST_NAME) function(add_test_executable TEST_NAME)
message("adding test ${TEST_NAME}") message("adding test ${TEST_NAME}")
add_executable(${TEST_NAME} ${ARGN}) set(result 1)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>) if(DEFINED DTYPES)
add_dependencies(tests ${TEST_NAME}) foreach(source IN LISTS ARGN)
add_dependencies(check ${TEST_NAME}) set(test 0)
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(ARGN)
add_executable(${TEST_NAME} ${ARGN})
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
set(result 0)
endif()
#message("add_test returns ${result}")
return(PROPAGATE result)
endfunction(add_test_executable TEST_NAME) endfunction(add_test_executable TEST_NAME)
include(GoogleTest) include(GoogleTest)
function(add_gtest_executable TEST_NAME) function(add_gtest_executable TEST_NAME)
message("adding gtest ${TEST_NAME}") message("adding gtest ${TEST_NAME}")
add_executable(${TEST_NAME} ${ARGN}) set(result 1)
add_dependencies(tests ${TEST_NAME}) if(DEFINED DTYPES)
add_dependencies(check ${TEST_NAME}) foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing gtest ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(ARGN)
add_executable(${TEST_NAME} ${ARGN})
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
# suppress gtest warnings # suppress gtest warnings
target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef) target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
target_link_libraries(${TEST_NAME} PRIVATE gtest_main) target_link_libraries(${TEST_NAME} PRIVATE gtest_main)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>) add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
set(result 0)
endif()
#message("add_gtest returns ${result}")
return(PROPAGATE result)
endfunction(add_gtest_executable TEST_NAME) endfunction(add_gtest_executable TEST_NAME)
add_subdirectory(magic_number_division) add_subdirectory(magic_number_division)
......
...@@ -2,25 +2,21 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,25 +2,21 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp)
add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) if(result EQUAL 0)
target_link_libraries(test_batched_gemm_fp16 PRIVATE utility) target_link_libraries(test_batched_gemm_fp16 PRIVATE utility device_batched_gemm_instance)
target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance)
endif() endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) if(result EQUAL 0)
target_link_libraries(test_batched_gemm_fp32 PRIVATE utility) target_link_libraries(test_batched_gemm_fp32 PRIVATE utility device_batched_gemm_instance)
target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance)
endif() endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bf16 PRIVATE utility) target_link_libraries(test_batched_gemm_bf16 PRIVATE utility device_batched_gemm_instance)
target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance)
endif() endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) if(result EQUAL 0)
target_link_libraries(test_batched_gemm_int8 PRIVATE utility) target_link_libraries(test_batched_gemm_int8 PRIVATE utility device_batched_gemm_instance)
target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance)
endif() endif()
set(target 1) set(target 1)
endif() endif()
......
...@@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_custom_target(test_batched_gemm_gemm)
add_custom_target(test_batched_gemm_gemm) add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp)
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) if(result EQUAL 0)
target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
set(target 1) set(target 1)
endif() endif()
endif() endif()
endforeach() endforeach()
\ No newline at end of file
if(DL_KERNELS) add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d_dl.cpp)
add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp) if(result EQUAL 0)
target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance) target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance)
endif() endif()
...@@ -2,10 +2,9 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,10 +2,9 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp)
add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) if(result EQUAL 0)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility) target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance)
set(target 1) set(target 1)
endif() endif()
endif() endif()
......
...@@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_custom_target(test_batched_gemm_softmax_gemm)
add_custom_target(test_batched_gemm_softmax_gemm) add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) if(result EQUAL 0)
target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
set(target 1) set(target 1)
endif() endif()
endif() endif()
endforeach() endforeach()
\ No newline at end of file
...@@ -2,25 +2,28 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,25 +2,28 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_custom_target(test_batched_gemm_softmax_gemm_permute)
add_custom_target(test_batched_gemm_softmax_gemm_permute) add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
endif() if(result EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) endif()
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) if(result EQUAL 0)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
endif() endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp)
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) if(result EQUAL 0)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) endif()
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp)
endif() if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)
endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment