Commit 50a6b8d7 authored by Chao Liu's avatar Chao Liu
Browse files

update dynamic gemm

parent b90cccf7
...@@ -253,6 +253,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -253,6 +253,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
#if 1
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_block_copy.RunRead(a_k_m_global_desc, p_a_global); a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
...@@ -261,44 +262,54 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -261,44 +262,54 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double); a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double); b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
} }
#endif
Float* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_block_odd = p_b_block_double + b_block_space_size;
// LDS double buffer: main body // LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock; for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
k_block_data_begin += 2 * KPerBlock) k_block_data_begin += 2 * KPerBlock)
{ {
#pragma unroll // even iteration
for(index_t iloop = 0; iloop < 2; ++iloop) a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
{ b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now = __syncthreads();
even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;
Float* p_a_block_next = // LDS doubel buffer: load next data from device mem
even_loop ? p_a_block_double + a_block_space_size : p_a_block_double; a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
Float* p_b_block_next = b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step); // LDS double buffer: GEMM on current data
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step); block_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
__syncthreads(); // LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
// LDS doubel buffer: load next data from device mem // odd iteration
a_block_copy.RunRead(a_k_m_global_desc, p_a_global); a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global); b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
// LDS double buffer: GEMM on current data __syncthreads();
block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS // LDS doubel buffer: load next data from device mem
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next); a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next); b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
}
// LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
} }
#if 1
// LDS double buffer: tail // LDS double buffer: tail
{ {
if constexpr(IsEvenNumberKBlockLoop) // if has 2 iteration left if constexpr(IsEvenNumberKBlockLoop) // if has 2 iteration left
...@@ -334,6 +345,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -334,6 +345,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
} }
} }
#endif
// output: register to global memory // output: register to global memory
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment