Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
50a6b8d7
Commit
50a6b8d7
authored
Jan 06, 2021
by
Chao Liu
Browse files
update dynamic gemm
parent
b90cccf7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
24 deletions
+36
-24
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
...kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
+36
-24
No files found.
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
View file @
50a6b8d7
...
...
@@ -253,6 +253,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
);
#if 1
// LDS double buffer: preload data into LDS
{
a_block_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
...
...
@@ -261,44 +262,54 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
a_block_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_double
);
b_block_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_double
);
}
#endif
Float
*
p_a_block_even
=
p_a_block_double
;
Float
*
p_b_block_even
=
p_b_block_double
;
Float
*
p_a_block_odd
=
p_a_block_double
+
a_block_space_size
;
Float
*
p_b_block_odd
=
p_b_block_double
+
b_block_space_size
;
// LDS double buffer: main body
for
(
index_t
k_block_data_begin
=
0
;
k_block_data_begin
<
K
-
2
*
KPerBlock
;
k_block_data_begin
+=
2
*
KPerBlock
)
{
#pragma unroll
for
(
index_t
iloop
=
0
;
iloop
<
2
;
++
iloop
)
{
const
bool
even_loop
=
(
iloop
%
2
==
0
);
// even iteration
a_block_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
Float
*
p_a_block_now
=
even_loop
?
p_a_block_double
:
p_a_block_double
+
a_block_space_size
;
Float
*
p_b_block_now
=
even_loop
?
p_b_block_double
:
p_b_block_double
+
b_block_space_size
;
__syncthreads
();
Float
*
p_a_block_next
=
even_loop
?
p_a_block_double
+
a_block_space_size
:
p_a_block_double
;
Float
*
p_b_block_next
=
even_loop
?
p_b_block_double
+
b_block_space_size
:
p_b_block_double
;
// LDS doubel buffer: load next data from device mem
a_block_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
b_block_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
a_block_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
// LDS double buffer: GEMM on current data
block_gemm
.
Run
(
p_a_block_even
,
p_b_block_even
,
p_c_thread
);
__syncthreads
();
// LDS double buffer: store next data to LDS
a_block_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_odd
);
b_block_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_odd
);
// LDS doubel buffer: load next data from device mem
a_block_copy
.
RunRead
(
a_k_m_global_desc
,
p_
a_
g
lo
bal
);
b_block_copy
.
RunRead
(
b_k_n_global_desc
,
p_
b_
g
lo
bal
);
// odd iteration
a_block_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_
b
lo
ck_slice_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_
b
lo
ck_slice_copy_step
);
// LDS double buffer: GEMM on current data
block_gemm
.
Run
(
p_a_block_now
,
p_b_block_now
,
p_c_thread
);
__syncthreads
();
// LDS double buffer: store next data to LDS
a_block_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_next
);
b_block_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_next
);
}
// LDS doubel buffer: load next data from device mem
a_block_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
b_block_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
// LDS double buffer: GEMM on current data
block_gemm
.
Run
(
p_a_block_odd
,
p_b_block_odd
,
p_c_thread
);
// LDS double buffer: store next data to LDS
a_block_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_even
);
b_block_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_even
);
}
#if 1
// LDS double buffer: tail
{
if
constexpr
(
IsEvenNumberKBlockLoop
)
// if has 2 iteration left
...
...
@@ -334,6 +345,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
block_gemm
.
Run
(
p_a_block_double
,
p_b_block_double
,
p_c_thread
);
}
}
#endif
// output: register to global memory
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment