Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
49282565
Commit
49282565
authored
Nov 05, 2023
by
Jing Zhang
Browse files
add grouped gemm instalces
parent
01ee5e53
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
448 additions
and
168 deletions
+448
-168
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+7
-1
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
...ary/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
+78
-76
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt
...tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt
+1
-1
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
...emm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
+124
-0
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
...emm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
+124
-0
profiler/src/CMakeLists.txt
profiler/src/CMakeLists.txt
+92
-89
profiler/src/profile_grouped_gemm.cpp
profiler/src/profile_grouped_gemm.cpp
+22
-1
No files found.
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
49282565
...
...
@@ -16,7 +16,7 @@ namespace element_wise {
extern
"C"
__device__
float
__ocml_native_recip_f32
(
float
);
#endif
struct
PassThrough
struct
PassThrough
Pack2
{
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
...
...
@@ -64,6 +64,12 @@ struct PassThrough
}
constexpr
const
static
bool
is_pack2_invocable
=
true
;
};
struct
PassThrough
{
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
double
,
double
>
(
double
&
y
,
const
double
&
x
)
const
...
...
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
View file @
49282565
...
...
@@ -4,97 +4,99 @@ list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES
#
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
#
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
#
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
#
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp
list
(
APPEND GEMM_INSTANCES
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp
#
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp
#
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
#
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
#
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp
)
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES
#
device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
#
device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
#
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp
#
device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
#
device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
#
device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp
#
device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp
#
device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp
#
device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp
#
device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp
#
device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp
#
device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp
#
device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp
#
device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp
#
device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
#
device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp
#
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp
#
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
)
device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
)
list
(
APPEND GEMM_INSTANCES
#
device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
#
device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp
#
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
#
device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
#
device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
#
device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp
#
device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
#
device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp
list
(
APPEND GEMM_INSTANCES
device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
list
(
APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp
list
(
APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp
)
...
...
@@ -108,21 +110,21 @@ if (ENABLE_PIPELINE_V2_OPT)
CK_USE_WAVES_PER_EU=1
CK_MIN_WAVES_PER_EU=1
CK_MAX_WAVES_PER_EU=1
)
)
set
(
IGLP_OPT_DEFS
CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT=1
)
)
# TODO: The "-vectorize-slp=false" LLVM option is a workaround to prevent inefficient instruction scheduling
# caused by the SLP Vectorizer. Remove this option after fix the SLP Vectorizer issue.
# layout=NT
set_source_files_properties
(
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
COMPILE_OPTIONS
";-mllvm;-vectorize-slp=false"
COMPILE_DEFINITIONS
"
${
WAVES_PER_EU_DEFS
}
;
${
IGLP_OPT_DEFS
}
"
)
COMPILE_OPTIONS
";-mllvm;-vectorize-slp=false"
COMPILE_DEFINITIONS
"
${
WAVES_PER_EU_DEFS
}
;
${
IGLP_OPT_DEFS
}
"
)
# layout=NN
set_source_files_properties
(
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
COMPILE_OPTIONS
";-mllvm;-vectorize-slp=false"
COMPILE_DEFINITIONS
"
${
WAVES_PER_EU_DEFS
}
;
${
IGLP_OPT_DEFS
}
"
)
COMPILE_OPTIONS
";-mllvm;-vectorize-slp=false"
COMPILE_DEFINITIONS
"
${
WAVES_PER_EU_DEFS
}
;
${
IGLP_OPT_DEFS
}
"
)
# layout=TT
set_source_files_properties
(
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
COMPILE_OPTIONS
";;"
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt
View file @
49282565
...
...
@@ -7,6 +7,6 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
)
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
0 → 100644
View file @
49282565
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
0 → 100644
View file @
49282565
This diff is collapsed.
Click to expand it.
profiler/src/CMakeLists.txt
View file @
49282565
...
...
@@ -3,51 +3,53 @@ set(PROFILER_SOURCES
profiler.cpp
profile_gemm.cpp
profile_gemm_splitk.cpp
#
profile_gemm_bias_add_reduce.cpp
#
profile_gemm_add_multiply.cpp
#
profile_gemm_multiply_add.cpp
#
profile_gemm_reduce.cpp
#
profile_batched_gemm.cpp
#
profile_batched_gemm_reduce.cpp
#
profile_conv_fwd.cpp
#
profile_conv_fwd_bias_relu.cpp
#
profile_conv_fwd_bias_relu_add.cpp
#
profile_conv_bwd_data.cpp
#
profile_grouped_conv_fwd.cpp
#
profile_grouped_conv_bwd_weight.cpp
#
profile_reduce.cpp
#
profile_groupnorm.cpp
#
profile_layernorm.cpp
#
profile_max_pool3d_fwd.cpp
#
profile_avg_pool3d_bwd.cpp
#
profile_max_pool3d_bwd.cpp
#
profile_softmax.cpp
#
profile_batchnorm_fwd.cpp
#
profile_batchnorm_bwd.cpp
#
profile_batchnorm_infer.cpp
#
profile_grouped_conv_bwd_data.cpp
#
profile_conv_tensor_rearrange.cpp
profile_gemm_bias_add_reduce.cpp
profile_gemm_add_multiply.cpp
profile_gemm_multiply_add.cpp
profile_gemm_reduce.cpp
profile_batched_gemm.cpp
profile_batched_gemm_reduce.cpp
profile_conv_fwd.cpp
profile_conv_fwd_bias_relu.cpp
profile_conv_fwd_bias_relu_add.cpp
profile_conv_bwd_data.cpp
profile_grouped_conv_fwd.cpp
profile_grouped_conv_bwd_weight.cpp
profile_reduce.cpp
profile_groupnorm.cpp
profile_layernorm.cpp
profile_max_pool3d_fwd.cpp
profile_avg_pool3d_bwd.cpp
profile_max_pool3d_bwd.cpp
profile_softmax.cpp
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
profile_grouped_conv_bwd_data.cpp
profile_conv_tensor_rearrange.cpp
)
#if(DL_KERNELS)
#list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
#endif()
#if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
#list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
#list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
#list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
#list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
#list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
#list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
#list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
#list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
if
(
DL_KERNELS
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_streamk.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm.cpp
)
#
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
#
endif()
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp
)
endif
()
#
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
#
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
#
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
#
endif()
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp
)
list
(
APPEND PROFILER_SOURCES profile_contraction_scale.cpp
)
endif
()
set
(
PROFILER_EXECUTABLE ckProfiler
)
...
...
@@ -57,57 +59,58 @@ target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE utility
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_splitk_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_multiply_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bias_add_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv1d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv1d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv3d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv1d_bwd_weight_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_weight_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_weight_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_bias_relu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_bias_relu_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_normalization_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_softmax_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batchnorm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_pool3d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_avg_pool3d_bwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_max_pool_bwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_image_to_column_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_column_to_image_instance
)
#
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
#
endif()
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_scale_instance
)
endif
()
#
if(DL_KERNELS)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
#
endif()
if
(
DL_KERNELS
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_multi_d_instance
)
endif
()
#
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_relu_add_layernorm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_add_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_streamk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_add_relu_gemm_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_instance
)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
#endif()
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_fastgelu_instance
)
endif
()
rocm_install
(
TARGETS
${
PROFILER_EXECUTABLE
}
COMPONENT profiler
)
profiler/src/profile_grouped_gemm.cpp
View file @
49282565
...
...
@@ -27,7 +27,8 @@ enum struct GemmDataType
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
F16_F8_F16
,
// 4
F8_F16_F16
,
// 4
F16_F8_F16
,
// 5
};
#define OP_NAME "grouped_gemm"
...
...
@@ -170,6 +171,26 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideCs
,
kbatch
);
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_impl
<
ck
::
f8_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_impl
<
ck
::
half_t
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment