Commit 50a6b8d7 authored by Chao Liu's avatar Chao Liu
Browse files

update dynamic gemm

parent b90cccf7
......@@ -253,6 +253,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
#if 1
// LDS double buffer: preload data into LDS
{
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
......@@ -261,26 +262,36 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
}
#endif
Float* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_block_odd = p_b_block_double + b_block_space_size;
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
// even iteration
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads();
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;
// LDS doubel buffer: load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;
// LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
// odd iteration
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
......@@ -291,14 +302,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
block_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next);
}
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
}
#if 1
// LDS double buffer: tail
{
if constexpr(IsEvenNumberKBlockLoop) // if has 2 iteration left
......@@ -334,6 +345,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
#endif
// output: register to global memory
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment