Commit c9013009 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 114c2646 84dcf5d0
...@@ -6,30 +6,43 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -6,30 +6,43 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_custom_target(example_gemm_reduce_xdl_max) add_custom_target(example_gemm_reduce_xdl_max)
add_custom_target(example_gemm_reduce_xdl_mean_meansquare) add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
add_custom_target(example_gemm_add_add_mean_meansquare_xdl) add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) if(result EQUAL 0)
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
endif()
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
endif()
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
endif() endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
endif()
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
endif() endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
endif()
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
endif() endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
endif()
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
endif() endif()
...@@ -40,7 +53,9 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -40,7 +53,9 @@ foreach(gpu IN LISTS GPU_TARGETS)
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
endif()
endif() endif()
set(target 1) set(target 1)
endif() endif()
......
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) if(result EQUAL 0)
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
if(DL_KERNELS)
add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) if(result EQUAL 0)
endif() target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility)
endif() endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
...@@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
endif()
...@@ -3,22 +3,20 @@ set(target 0) ...@@ -3,22 +3,20 @@ set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight) add_custom_target(example_grouped_conv_bwd_weight)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
endif() endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
endif() endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_custom_target(example_grouped_conv_bwd_weight_dl)
if(DL_KERNELS) add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
add_custom_target(example_grouped_conv_bwd_weight_dl) if(result EQUAL 0)
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) endif()
endif()
endif()
\ No newline at end of file
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
using InDataType = F16; using InDataType = F16;
using WeiDataType = F16; using WeiDataType = F16;
...@@ -15,44 +15,55 @@ using WeiElementOp = PassThrough; ...@@ -15,44 +15,55 @@ using WeiElementOp = PassThrough;
using OutElementOp = PassThrough; using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl<
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl< NDimSpatial, // NDimSpatial
NDimSpatial, // NDimSpatial ck::tuple_element_t<NDimSpatial - 1,
InDataType, // InDataType ck::Tuple<ck::tensor_layout::convolution::GNWC,
WeiDataType, // WeiDataType ck::tensor_layout::convolution::GNHWC,
OutDataType, // OutDataType ck::tensor_layout::convolution::GNDHWC>>, // InLayout
AccDataType, // AccDataType ck::tuple_element_t<NDimSpatial - 1,
InElementOp, // InElementwiseOperation ck::Tuple<ck::tensor_layout::convolution::GKXC,
WeiElementOp, // WeiElementwiseOperation ck::tensor_layout::convolution::GKYXC,
OutElementOp, // OutElementwiseOperation ck::tensor_layout::convolution::GKZYXC>>, // WeiLayout
ConvBwdWeightDefault, // ConvBackwardWeightSpecialization ck::tuple_element_t<NDimSpatial - 1,
256, // BlockSize ck::Tuple<ck::tensor_layout::convolution::GNWK,
128, // MPerBlock ck::tensor_layout::convolution::GNHWK,
128, // NPerBlock ck::tensor_layout::convolution::GNDHWK>>, // OutLayout
16, // K0PerBlock InDataType, // InDataType
2, // K1 WeiDataType, // WeiDataType
4, // M1PerThread OutDataType, // OutDataType
4, // N1PerThread AccDataType, // AccDataType
1, // KPerThread InElementOp, // InElementwiseOperation
S<8, 2>, // M1N1ThreadClusterM1Xs WeiElementOp, // WeiElementwiseOperation
S<8, 2>, // M1N1ThreadClusterN1Xs OutElementOp, // OutElementwiseOperation
S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 ConvBwdWeightDefault, // ConvBackwardWeightSpecialization
S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 256, // BlockSize
S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder 128, // MPerBlock
S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder 128, // NPerBlock
S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 16, // K0PerBlock
S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder 2, // K1
S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 4, // M1PerThread
S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 4, // N1PerThread
S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 1, // KPerThread
S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder S<8, 2>, // M1N1ThreadClusterM1Xs
S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder S<8, 2>, // M1N1ThreadClusterN1Xs
S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder
S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder
5, // CThreadTransferSrcDstVectorDim S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
4>; // CThreadTransferDstScalarPerVector S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder
S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
4>; // CThreadTransferDstScalarPerVector
#include "run_grouped_conv_bwd_weight_example.inc" #include "run_grouped_conv_bwd_weight_example.inc"
......
...@@ -14,20 +14,8 @@ template <ck::index_t NDimSpatial> ...@@ -14,20 +14,8 @@ template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_param) const ck::utils::conv::ConvParam& conv_param)
{ {
ck::index_t split_k;
// Set split_k = 2 for xdl op, split_k = 1 for dl
// Dl op doesn't support split_k > 1 // Dl op doesn't support split_k > 1
// TODO: Add Dl op split_k > 1 support constexpr ck::index_t split_k = 1;
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102"))
{
split_k = 2;
}
else
{
split_k = 1;
}
const auto in_g_n_c_wis_desc = const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
......
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
...@@ -10,4 +9,4 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -10,4 +9,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
endif()
add_custom_target(example_cgemm_xdl) add_custom_target(example_cgemm_xdl)
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16)
endif() endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16)
endif() endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp) add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
endif() endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp)
add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8)
endif() endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
......
add_custom_target(example_batched_gemm_xdl) add_custom_target(example_batched_gemm_xdl)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32) add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32)
endif() endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16)
endif() endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp)
add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bfp16) add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16)
endif() endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8)
endif() endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
endif()
endif() endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp) add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
endif() add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp)
add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp) add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp)
add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()
endif() endif()
...@@ -5,27 +5,31 @@ set(target 0) ...@@ -5,27 +5,31 @@ set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0) if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
add_custom_target(example_grouped_conv_fwd_multiple_d) add_custom_target(example_grouped_conv_fwd_multiple_d)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) endif()
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
endif() endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
endif() endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
endif() endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
endif() endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
endif()
endif() # USE_BITINT_EXTENSION_INT4 endif() # USE_BITINT_EXTENSION_INT4
set(target 1) set(target 1)
...@@ -35,12 +39,8 @@ endforeach() ...@@ -35,12 +39,8 @@ endforeach()
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0) if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list2 gfx908 gfx90a)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0) if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
endif() add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
...@@ -20,7 +14,5 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -20,7 +14,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
endforeach() endforeach()
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
endif()
endif() endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
endif()
add_custom_target(example_gemm_scale_softmax_gemm) add_custom_target(example_gemm_scale_softmax_gemm)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
endif()
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
endif()
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
endif()
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
endif()
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
endif() endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16)
endif()
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16)
endif() endif()
...@@ -3,25 +3,28 @@ set(target 0) ...@@ -3,25 +3,28 @@ set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_splitK_gemm_xdl) add_custom_target(example_splitK_gemm_xdl)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
endif() endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
endif() endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bfp16) add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
endif() endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
endif() endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
endif()
endif() endif()
set(target 1) set(target 1)
endif() endif()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment