Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7c2b82ca
Commit
7c2b82ca
authored
May 11, 2023
by
Po-Yen, Chen
Browse files
Add option to enable/disable compiler opt
parent
da265f69
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
55 additions
and
12 deletions
+55
-12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
+21
-6
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
+7
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+7
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+7
-2
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
...ary/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
+13
-0
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
View file @
7c2b82ca
...
@@ -56,10 +56,11 @@ struct GridwiseGemmPipeline_v1<1>
...
@@ -56,10 +56,11 @@ struct GridwiseGemmPipeline_v1<1>
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
index_t
num_loop
#if ENABLE_DUMP_CLOCK
#if ENABLE_DUMP_CLOCK
,
long
&
loop_start
,
,
long
&
loop_start
,
long
&
loop_end
long
&
loop_end
#endif
#endif
)
)
{
{
#if ENABLE_DUMP_CLOCK
#if ENABLE_DUMP_CLOCK
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -94,6 +95,10 @@ struct GridwiseGemmPipeline_v1<1>
...
@@ -94,6 +95,10 @@ struct GridwiseGemmPipeline_v1<1>
do
do
{
{
#ifdef USE_IGLP_OPT
__builtin_amdgcn_iglp_opt
(
1
);
#endif
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
block_sync_lds
();
block_sync_lds
();
...
@@ -183,10 +188,11 @@ struct GridwiseGemmPipeline_v1<2>
...
@@ -183,10 +188,11 @@ struct GridwiseGemmPipeline_v1<2>
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
index_t
num_loop
#if ENABLE_DUMP_CLOCK
#if ENABLE_DUMP_CLOCK
,
long
&
loop_start
,
,
long
&
loop_start
,
long
&
loop_end
long
&
loop_end
#endif
#endif
)
)
{
{
#if ENABLE_DUMP_CLOCK
#if ENABLE_DUMP_CLOCK
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -226,6 +232,10 @@ struct GridwiseGemmPipeline_v1<2>
...
@@ -226,6 +232,10 @@ struct GridwiseGemmPipeline_v1<2>
do
do
{
{
#ifdef USE_IGLP_OPT
__builtin_amdgcn_iglp_opt
(
1
);
#endif
// Move
// Move
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
...
@@ -355,10 +365,11 @@ struct GridwiseGemmPipelineInterwave_v1<1>
...
@@ -355,10 +365,11 @@ struct GridwiseGemmPipelineInterwave_v1<1>
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
index_t
num_loop
#if ENABLE_DUMP_CLOCK
#if ENABLE_DUMP_CLOCK
,
long
&
loop_start
,
,
long
&
loop_start
,
long
&
loop_end
long
&
loop_end
#endif
#endif
)
)
{
{
#if ENABLE_DUMP_CLOCK
#if ENABLE_DUMP_CLOCK
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -393,6 +404,10 @@ struct GridwiseGemmPipelineInterwave_v1<1>
...
@@ -393,6 +404,10 @@ struct GridwiseGemmPipelineInterwave_v1<1>
do
do
{
{
#ifdef USE_IGLP_OPT
__builtin_amdgcn_iglp_opt
(
1
);
#endif
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
block_sync_lds
();
block_sync_lds
();
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
View file @
7c2b82ca
...
@@ -51,10 +51,11 @@ struct GridwiseGemmPipeline_v2
...
@@ -51,10 +51,11 @@ struct GridwiseGemmPipeline_v2
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
index_t
num_loop
#if ENABLE_DUMP_CLOCK
#if ENABLE_DUMP_CLOCK
,
long
&
loop_start
,
,
long
&
loop_start
,
long
&
loop_end
long
&
loop_end
#endif
#endif
)
)
{
{
#if ENABLE_DUMP_CLOCK
#if ENABLE_DUMP_CLOCK
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -97,6 +98,10 @@ struct GridwiseGemmPipeline_v2
...
@@ -97,6 +98,10 @@ struct GridwiseGemmPipeline_v2
do
do
{
{
#ifdef USE_IGLP_OPT
__builtin_amdgcn_iglp_opt
(
1
);
#endif
block_sync_lds
();
block_sync_lds
();
// GEMM i
// GEMM i
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
7c2b82ca
...
@@ -18,6 +18,9 @@
...
@@ -18,6 +18,9 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
>
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
>
#ifdef USE_WAVES_PER_EU
__attribute__
((
amdgpu_waves_per_eu
(
1
,
1
)))
#endif
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
@@ -833,9 +836,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -833,9 +836,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_thread_buf
,
c_thread_buf
,
num_k_block_main_loop
num_k_block_main_loop
#if ENABLE_DUMP_CLOCK
#if ENABLE_DUMP_CLOCK
,
loop_start
,
loop_end
,
loop_start
,
loop_end
#endif
#endif
);
);
// shuffle C and write out
// shuffle C and write out
{
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
7c2b82ca
...
@@ -18,6 +18,9 @@
...
@@ -18,6 +18,9 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
>
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
>
#ifdef USE_WAVES_PER_EU
__attribute__
((
amdgpu_waves_per_eu
(
1
,
1
)))
#endif
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
@@ -678,9 +681,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -678,9 +681,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
,
c_thread_buf
,
num_k_block_main_loop
num_k_block_main_loop
#if ENABLE_DUMP_CLOCK
#if ENABLE_DUMP_CLOCK
,
loop_start
,
loop_end
,
loop_start
,
loop_end
#endif
#endif
);
);
// output: register to global memory
// output: register to global memory
{
{
...
...
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
View file @
7c2b82ca
...
@@ -41,3 +41,16 @@ add_instance_library(device_gemm_instance
...
@@ -41,3 +41,16 @@ add_instance_library(device_gemm_instance
# device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
# device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
# device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
# device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
)
)
# Opt-in switch for building selected f16 GEMM instances with the IGLP-related
# compiler options and preprocessor definitions below.
# Declared with option() (instead of a hard-coded set()) so it can actually be
# toggled at configure time, e.g. `cmake -DENABLE_IGLP_OPT=ON ...`.
# The default remains OFF, so existing builds are unaffected.
option(ENABLE_IGLP_OPT "Build selected f16 GEMM instances with IGLP compiler options" OFF)

if(ENABLE_IGLP_OPT)
    # The three instance sources receive identical per-file properties, so they
    # are configured in a single call to keep the flag lists from drifting apart.
    # COMPILE_DEFINITIONS: USE_IGLP_OPT / USE_WAVES_PER_EU are the macros tested
    # by `#ifdef` in the gridwise GEMM headers changed in this same commit.
    # NOTE(review): --save-temps keeps compiler intermediates next to the build;
    # presumably intentional for inspecting generated ISA - confirm before
    # enabling this option in CI.
    set_source_files_properties(
        device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
        device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
        device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
        PROPERTIES
            COMPILE_OPTIONS ";--save-temps;-Wno-gnu-line-marker;-mllvm;-amdgpu-enable-max-ilp-scheduling-strategy;"
            COMPILE_DEFINITIONS ";USE_IGLP_OPT;USE_WAVES_PER_EU;"
    )
endif()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment