Commit 7c2b82ca authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add option to enable/disable compiler opt

parent da265f69
......@@ -56,7 +56,8 @@ struct GridwiseGemmPipeline_v1<1>
CThreadBuffer& c_thread_buf,
index_t num_loop
#if ENABLE_DUMP_CLOCK
, long& loop_start,
,
long& loop_start,
long& loop_end
#endif
)
......@@ -94,6 +95,10 @@ struct GridwiseGemmPipeline_v1<1>
do
{
#ifdef USE_IGLP_OPT
__builtin_amdgcn_iglp_opt(1);
#endif
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
......@@ -183,7 +188,8 @@ struct GridwiseGemmPipeline_v1<2>
CThreadBuffer& c_thread_buf,
index_t num_loop
#if ENABLE_DUMP_CLOCK
, long& loop_start,
,
long& loop_start,
long& loop_end
#endif
)
......@@ -226,6 +232,10 @@ struct GridwiseGemmPipeline_v1<2>
do
{
#ifdef USE_IGLP_OPT
__builtin_amdgcn_iglp_opt(1);
#endif
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
......@@ -355,7 +365,8 @@ struct GridwiseGemmPipelineInterwave_v1<1>
CThreadBuffer& c_thread_buf,
index_t num_loop
#if ENABLE_DUMP_CLOCK
, long& loop_start,
,
long& loop_start,
long& loop_end
#endif
)
......@@ -393,6 +404,10 @@ struct GridwiseGemmPipelineInterwave_v1<1>
do
{
#ifdef USE_IGLP_OPT
__builtin_amdgcn_iglp_opt(1);
#endif
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
......
......@@ -51,7 +51,8 @@ struct GridwiseGemmPipeline_v2
CThreadBuffer& c_thread_buf,
index_t num_loop
#if ENABLE_DUMP_CLOCK
, long& loop_start,
,
long& loop_start,
long& loop_end
#endif
)
......@@ -97,6 +98,10 @@ struct GridwiseGemmPipeline_v2
do
{
#ifdef USE_IGLP_OPT
__builtin_amdgcn_iglp_opt(1);
#endif
block_sync_lds();
// GEMM i
......
......@@ -18,6 +18,9 @@
namespace ck {
template <typename GridwiseGemm, bool HasMainKBlockLoop>
#ifdef USE_WAVES_PER_EU
__attribute__((amdgpu_waves_per_eu(1, 1)))
#endif
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -833,7 +836,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_thread_buf,
num_k_block_main_loop
#if ENABLE_DUMP_CLOCK
, loop_start, loop_end
,
loop_start,
loop_end
#endif
);
......
......@@ -18,6 +18,9 @@
namespace ck {
template <typename GridwiseGemm, bool HasMainKBlockLoop>
#ifdef USE_WAVES_PER_EU
__attribute__((amdgpu_waves_per_eu(1, 1)))
#endif
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -678,7 +681,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf,
num_k_block_main_loop
#if ENABLE_DUMP_CLOCK
, loop_start, loop_end
,
loop_start,
loop_end
#endif
);
......
......@@ -41,3 +41,16 @@ add_instance_library(device_gemm_instance
# device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
# device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
)
set(ENABLE_IGLP_OPT OFF)
if(ENABLE_IGLP_OPT)
set_source_files_properties(device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp PROPERTIES
COMPILE_OPTIONS ";--save-temps;-Wno-gnu-line-marker;-mllvm;-amdgpu-enable-max-ilp-scheduling-strategy;"
COMPILE_DEFINITIONS ";USE_IGLP_OPT;USE_WAVES_PER_EU;")
set_source_files_properties(device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp PROPERTIES
COMPILE_OPTIONS ";--save-temps;-Wno-gnu-line-marker;-mllvm;-amdgpu-enable-max-ilp-scheduling-strategy;"
COMPILE_DEFINITIONS ";USE_IGLP_OPT;USE_WAVES_PER_EU;")
set_source_files_properties(device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp PROPERTIES
COMPILE_OPTIONS ";--save-temps;-Wno-gnu-line-marker;-mllvm;-amdgpu-enable-max-ilp-scheduling-strategy;"
COMPILE_DEFINITIONS ";USE_IGLP_OPT;USE_WAVES_PER_EU;")
endif(ENABLE_IGLP_OPT)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment