Commit bec0399a authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/composable_kernel...

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/composable_kernel into barkocot/conv-elementwise
parents 6d149b65 bf435140
...@@ -32,12 +32,10 @@ if (DTYPES) ...@@ -32,12 +32,10 @@ if (DTYPES)
if (DTYPES MATCHES "fp8") if (DTYPES MATCHES "fp8")
add_definitions(-DCK_ENABLE_FP8) add_definitions(-DCK_ENABLE_FP8)
set(CK_ENABLE_FP8 "ON") set(CK_ENABLE_FP8 "ON")
add_compile_options(-Wno-bit-int-extension)
endif() endif()
if (DTYPES MATCHES "bf8") if (DTYPES MATCHES "bf8")
add_definitions(-DCK_ENABLE_BF8) add_definitions(-DCK_ENABLE_BF8)
set(CK_ENABLE_BF8 "ON") set(CK_ENABLE_BF8 "ON")
add_compile_options(-Wno-bit-int-extension)
endif() endif()
if (DTYPES MATCHES "fp16") if (DTYPES MATCHES "fp16")
add_definitions(-DCK_ENABLE_FP16) add_definitions(-DCK_ENABLE_FP16)
...@@ -59,9 +57,11 @@ if (DTYPES) ...@@ -59,9 +57,11 @@ if (DTYPES)
else() else()
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
set(CK_ENABLE_ALL_DTYPES "ON") set(CK_ENABLE_ALL_DTYPES "ON")
add_compile_options(-Wno-bit-int-extension) # enable fp8 and bf8
endif() endif()
#for f8/bf8_t type
add_compile_options(-Wno-bit-int-extension)
if(DL_KERNELS) if(DL_KERNELS)
add_definitions(-DDL_KERNELS) add_definitions(-DDL_KERNELS)
set(CK_ENABLE_DL_KERNELS "ON") set(CK_ENABLE_DL_KERNELS "ON")
......
add_custom_target(example_gemm_dl) add_custom_target(example_gemm_dl)
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_dl example_gemm_dl_fp32)
add_dependencies(example_gemm_dl example_gemm_dl_fp32)
endif()
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_dl example_gemm_dl_fp16)
add_dependencies(example_gemm_dl example_gemm_dl_fp16)
endif()
add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp) add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp)
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_dl example_gemm_dl_int8)
add_dependencies(example_gemm_dl example_gemm_dl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_int4) add_example_dependencies(example_gemm_dl example_gemm_dl_int4)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
add_custom_target(example_gemm_xdl) add_custom_target(example_gemm_xdl)
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
endif()
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
endif()
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
endif()
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_custom_target(example_gemm_wmma) add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
endif()
endif() endif()
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn)
endif()
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8)
add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_int4) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
endif()
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
endif()
add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
endif()
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
endif()
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemm_add_add_fastgelu_xdl) add_custom_target(example_gemm_add_add_fastgelu_xdl)
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
set(target 1)
endif() endif()
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) endforeach()
if(result EQUAL 0)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
endif()
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
endif()
set(target 1)
endif()
endforeach()
\ No newline at end of file
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_fwd_reduce_xdl) add_custom_target(example_convnd_fwd_reduce_xdl)
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
if(result EQUAL 0) add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
endif()
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
if(result EQUAL 0) add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
endif() add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
if(result EQUAL 0)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
endif() add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
if(result EQUAL 0) if(USE_BITINT_EXTENSION_INT4)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
endif() add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
if(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) set(target 1)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) endif()
endif(USE_BITINT_EXTENSION_INT4) endforeach()
set(target 1)
endif()
endforeach()
\ No newline at end of file
add_custom_target(example_grouped_gemm_xdl) add_custom_target(example_grouped_gemm_xdl)
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
endif()
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16)
endif()
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16)
endif()
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16)
endif()
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp) add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16)
endif()
add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp) add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16)
endif()
add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp) add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
endif()
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
endif()
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp) add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif()
endif() endif()
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemm_reduce_xdl) add_custom_target(example_gemm_reduce_xdl)
add_custom_target(example_gemm_reduce_xdl_max) add_custom_target(example_gemm_reduce_xdl_max)
add_custom_target(example_gemm_reduce_xdl_mean_meansquare) add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
add_custom_target(example_gemm_add_add_mean_meansquare_xdl) add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
if(result EQUAL 0) add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
endif()
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
endif() add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
endif() add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
endif() add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
endif() add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
endif() add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) add_example_dependencies(example_gemm_reduce_xdl
endif() example_gemm_reduce_xdl_mean_meansquare
example_gemm_reduce_xdl_max
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) example_gemm_add_add_mean_meansquare_xdl)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) if(USE_BITINT_EXTENSION_INT4)
endif() add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
if(result EQUAL 0) endif()
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) set(target 1)
endif() endif()
add_dependencies(example_gemm_reduce_xdl
example_gemm_reduce_xdl_mean_meansquare
example_gemm_reduce_xdl_max
example_gemm_add_add_mean_meansquare_xdl)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
endif()
endif()
set(target 1)
endif()
endforeach() endforeach()
...@@ -2,36 +2,30 @@ list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,36 +2,30 @@ list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight) add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
endif() add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
endif() add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) endif()
if(result EQUAL 0) set(target 1)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) endif()
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
set(target 1)
endif() endif()
endif()
set(target 1)
endif()
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
endif()
set(target 1)
endif()
endforeach() endforeach()
add_custom_target(example_grouped_conv_bwd_weight_dl) add_custom_target(example_grouped_conv_bwd_weight_dl)
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
endif()
add_custom_target(example_cgemm_xdl) add_custom_target(example_cgemm_xdl)
add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
if(result EQUAL 0) add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16)
endif()
add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16)
endif()
add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp) add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp)
if(result EQUAL 0) add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
endif()
add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp)
if(result EQUAL 0) add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int8)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int4)
endif() endif()
add_custom_target(example_batched_gemm_xdl) add_custom_target(example_batched_gemm_xdl)
add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp)
if(result EQUAL 0) add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32)
endif()
add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16)
endif()
add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp) add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp)
if(result EQUAL 0) add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16)
endif()
add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
if(result EQUAL 0) add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp)
if(result EQUAL 0) add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
endif()
endif() endif()
...@@ -3,44 +3,38 @@ list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102) ...@@ -3,44 +3,38 @@ list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0) if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
add_custom_target(example_grouped_conv_fwd_multiple_d) add_custom_target(example_grouped_conv_fwd_multiple_d)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
if(result EQUAL 0) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
endif()
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
endif() add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
endif() add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
if(result EQUAL 0) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
endif()
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) if(USE_BITINT_EXTENSION_INT4)
if(result EQUAL 0) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
endif() endif() # USE_BITINT_EXTENSION_INT4
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) set(target 1)
if(result EQUAL 0) endif()
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
endif()
endif() # USE_BITINT_EXTENSION_INT4
set(target 1)
endif()
endforeach() endforeach()
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0) if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
add_custom_target(example_gemm_scale_softmax_gemm) add_custom_target(example_gemm_scale_softmax_gemm)
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
endif()
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
endif()
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
endif()
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
endif()
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
endif()
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16)
endif()
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
if(result EQUAL 0) add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16)
endif()
...@@ -4,28 +4,23 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -4,28 +4,23 @@ foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_splitK_gemm_xdl) add_custom_target(example_splitK_gemm_xdl)
add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
if(result EQUAL 0) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
endif() add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
endif() add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
if(result EQUAL 0) add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
endif()
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
if(result EQUAL 0) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
endif()
endif() endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
...@@ -2,27 +2,26 @@ list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,27 +2,26 @@ list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_data) add_custom_target(example_grouped_conv_bwd_data)
add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
if(result EQUAL 0) add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)
endif()
add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
endif() set(target 1)
set(target 1) endif()
endif()
endforeach() endforeach()
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_data) add_custom_target(example_grouped_conv_bwd_data)
add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
if(result EQUAL 0) add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)
endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
add_custom_target(example_permute) add_custom_target(example_permute)
add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_permute example_permute_1xHxW_fp16)
add_dependencies(example_permute example_permute_1xHxW_fp16)
endif()
add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_permute example_permute_NxHxW_fp16)
add_dependencies(example_permute example_permute_NxHxW_fp16)
endif()
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
if(result EQUAL 0) add_example_dependencies(example_permute example_permute_HxWx4_fp16)
add_dependencies(example_permute example_permute_HxWx4_fp16)
endif()
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_im2col_col2im) add_custom_target(example_im2col_col2im)
add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
add_dependencies(example_im2col_col2im example_image_to_column_f32) add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) add_example_dependencies(example_im2col_col2im example_image_to_column_f32)
add_dependencies(example_im2col_col2im example_column_to_image_f32)
set(target 1) add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp)
endif() add_example_dependencies(example_im2col_col2im example_column_to_image_f32)
set(target 1)
endif()
endforeach() endforeach()
...@@ -62,6 +62,12 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -62,6 +62,12 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
set(result ${result} PARENT_SCOPE) set(result ${result} PARENT_SCOPE)
endfunction(add_example_executable EXAMPLE_NAME) endfunction(add_example_executable EXAMPLE_NAME)
function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
if(result EQUAL 0)
add_dependencies(${EXAMPLE_NAME} ${FILE_NAME})
endif()
endfunction(add_example_dependencies EXAMPLE_NAME)
function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}") message("adding example ${EXAMPLE_NAME}")
set(result 1) set(result 1)
......
...@@ -113,7 +113,6 @@ struct PassThrough ...@@ -113,7 +113,6 @@ struct PassThrough
} }
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{ {
...@@ -143,9 +142,7 @@ struct PassThrough ...@@ -143,9 +142,7 @@ struct PassThrough
{ {
y = type_convert<f8_t>(x); y = type_convert<f8_t>(x);
} }
#endif
#if defined CK_ENABLE_BF8
template <> template <>
__host__ __device__ void operator()<bf8_t, bf8_t>(bf8_t& y, const bf8_t& x) const __host__ __device__ void operator()<bf8_t, bf8_t>(bf8_t& y, const bf8_t& x) const
{ {
...@@ -175,7 +172,6 @@ struct PassThrough ...@@ -175,7 +172,6 @@ struct PassThrough
{ {
y = ck::type_convert<bf8_t>(x); y = ck::type_convert<bf8_t>(x);
} }
#endif
}; };
struct UnaryConvert struct UnaryConvert
...@@ -204,7 +200,6 @@ struct ConvertBF16RTN ...@@ -204,7 +200,6 @@ struct ConvertBF16RTN
} }
}; };
#if defined CK_ENABLE_FP8
struct ConvertF8SR struct ConvertF8SR
{ {
// convert to fp8 using stochastic rounding (SR) // convert to fp8 using stochastic rounding (SR)
...@@ -221,7 +216,6 @@ struct ConvertF8SR ...@@ -221,7 +216,6 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x); y = f8_convert_sr<Y>(x);
} }
}; };
#endif
struct Scale struct Scale
{ {
......
...@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64> ...@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
} }
}; };
#if defined CK_ENABLE_FP8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{ {
...@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8> ...@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
#if defined CK_ENABLE_BF8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8>
{ {
...@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8> ...@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
{ {
...@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8> ...@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8>
{ {
...@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8> ...@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
template <typename base_type, template <typename base_type,
index_t MPerXdlops, index_t MPerXdlops,
...@@ -792,7 +784,6 @@ struct MfmaSelector ...@@ -792,7 +784,6 @@ struct MfmaSelector
} }
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() static constexpr auto GetMfma<f8_t, 32, 32>()
{ {
...@@ -804,9 +795,7 @@ struct MfmaSelector ...@@ -804,9 +795,7 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32f8f8; return MfmaInstr::mfma_f32_16x16x32f8f8;
} }
#endif
#if defined CK_ENABLE_BF8
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32>() static constexpr auto GetMfma<bf8_t, 32, 32>()
{ {
...@@ -818,9 +807,7 @@ struct MfmaSelector ...@@ -818,9 +807,7 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32bf8bf8; return MfmaInstr::mfma_f32_16x16x32bf8bf8;
} }
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>() static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{ {
...@@ -832,9 +819,7 @@ struct MfmaSelector ...@@ -832,9 +819,7 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32f8bf8; return MfmaInstr::mfma_f32_16x16x32f8bf8;
} }
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>() static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{ {
...@@ -846,7 +831,6 @@ struct MfmaSelector ...@@ -846,7 +831,6 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32bf8f8; return MfmaInstr::mfma_f32_16x16x32bf8f8;
} }
#endif
static constexpr auto selected_mfma = static constexpr auto selected_mfma =
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{}; mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
...@@ -1051,18 +1035,10 @@ struct XdlopsGemm ...@@ -1051,18 +1035,10 @@ struct XdlopsGemm
static_assert( static_assert(
is_same<base_type, double>::value || is_same<base_type, float>::value || is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value || is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value ||
#if defined CK_ENABLE_FP8 is_same<base_type, bf8_t>::value ||
|| is_same<base_type, f8_t>::value (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
#endif (is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value),
#if defined CK_ENABLE_BF8
|| is_same<base_type, bf8_t>::value
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|| (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
(is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value)
#endif
,
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_XDLOPS_HPP #pragma once
#define CK_AMD_XDLOPS_HPP
#include "data_type.hpp"
namespace ck { namespace ck {
...@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> ...@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
} }
}; };
#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8; struct intrin_mfma_f32_32x32x16f8f8;
...@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> ...@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif #endif
} }
}; };
#endif
#if defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8bf8; struct intrin_mfma_f32_32x32x16bf8bf8;
...@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> ...@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
#endif #endif
} }
}; };
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8; struct intrin_mfma_f32_32x32x16f8bf8;
...@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> ...@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
#endif #endif
} }
}; };
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8f8; struct intrin_mfma_f32_32x32x16bf8f8;
...@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> ...@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
#endif #endif
} }
}; };
#endif
} // namespace ck } // namespace ck
#endif
...@@ -9,15 +9,9 @@ namespace ck { ...@@ -9,15 +9,9 @@ namespace ck {
using bhalf_t = ushort; using bhalf_t = ushort;
using half_t = _Float16; using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 using int4_t = _BitInt(4);
using int4_t = _BitInt(4); using f8_t = _BitInt(8);
#endif using bf8_t = unsigned _BitInt(8);
#if defined CK_ENABLE_FP8
using f8_t = _BitInt(8);
#endif
#if defined CK_ENABLE_BF8
using bf8_t = unsigned _BitInt(8);
#endif
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
...@@ -148,23 +142,19 @@ struct scalar_type<int4_t> ...@@ -148,23 +142,19 @@ struct scalar_type<int4_t>
}; };
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
struct scalar_type<f8_t> struct scalar_type<f8_t>
{ {
using type = f8_t; using type = f8_t;
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
#endif
#if defined CK_ENABLE_BF8
template <> template <>
struct scalar_type<bf8_t> struct scalar_type<bf8_t>
{ {
using type = bf8_t; using type = bf8_t;
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
#endif
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
...@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type; ...@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type; using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8 // f8
#if defined CK_ENABLE_FP8
using f8x2_t = typename vector_type<f8_t, 2>::type; using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type; using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type; using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type; using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type; using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type; using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif
// bf8 // bf8
#if defined CK_ENABLE_BF8
using bf8x2_t = typename vector_type<bf8_t, 2>::type; using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type; using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type; using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type; using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type; using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type; using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
...@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t> ...@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t>
}; };
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template <> template <>
struct NumericLimits<f8_t> struct NumericLimits<f8_t>
{ {
...@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t> ...@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t>
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); } __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
}; };
#endif
#if defined CK_ENABLE_BF8
template <> template <>
struct NumericLimits<bf8_t> struct NumericLimits<bf8_t>
{ {
...@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t> ...@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t>
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); } __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
}; };
#endif
template <typename T> template <typename T>
struct NumericUtils struct NumericUtils
...@@ -1120,22 +1102,18 @@ struct NumericUtils<half_t> ...@@ -1120,22 +1102,18 @@ struct NumericUtils<half_t>
using bitwise_type = uint16_t; using bitwise_type = uint16_t;
}; };
#if defined CK_ENABLE_FP8
template <> template <>
struct NumericUtils<f8_t> struct NumericUtils<f8_t>
{ {
static constexpr int exp = 4; static constexpr int exp = 4;
static constexpr int mant = 3; static constexpr int mant = 3;
}; };
#endif
#if defined CK_ENABLE_BF8
template <> template <>
struct NumericUtils<bf8_t> struct NumericUtils<bf8_t>
{ {
static constexpr int exp = 5; static constexpr int exp = 5;
static constexpr int mant = 2; static constexpr int mant = 2;
}; };
#endif //
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment