Commit 8e35a579 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 89123dd7
...@@ -131,7 +131,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -131,7 +131,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
make_multi_index(KPerBlock, NPerBlock), max_lds_align); make_multi_index(KPerBlock, NPerBlock), max_lds_align);
// A matrix blockwise copy // A matrix blockwise copy
auto a_block_copy = auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock>, Sequence<KPerBlock, MPerBlock>,
...@@ -160,7 +160,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -160,7 +160,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
make_multi_index(0, 0)); make_multi_index(0, 0));
// B matrix blockwise copy // B matrix blockwise copy
auto b_block_copy = auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock>, Sequence<KPerBlock, NPerBlock>,
...@@ -219,7 +219,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -219,7 +219,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}); Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{});
const auto block_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize, BlockSize,
decltype(a_k_m_block_mtx_desc), decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc), decltype(b_k_n_block_mtx_desc),
...@@ -256,14 +256,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -256,14 +256,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
#if 1 #if 1
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_block_copy.RunRead(a_k_m_global_desc, p_a_global); a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global); b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double); a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double); b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
} }
#endif #endif
#if 1
Float* p_a_block_even = p_a_block_double; Float* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double; Float* p_b_block_even = p_b_block_double;
...@@ -275,65 +276,66 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -275,65 +276,66 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
k_block_data_begin += 2 * KPerBlock) k_block_data_begin += 2 * KPerBlock)
{ {
// even iteration // even iteration
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global); a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global); b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread); blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_odd); a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_odd); b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
// odd iteration // odd iteration
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global); a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global); b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread); blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_even); a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_even); b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
} }
#endif
#if 1 #if 1
// LDS double buffer: tail // LDS double buffer: tail
{ {
if constexpr(IsEvenNumberKBlockLoop) // if has 2 iteration left if constexpr(IsEvenNumberKBlockLoop) // if has 2 iteration left
{ {
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global); a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global); b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size); a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size); b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
block_gemm.Run(p_a_block_double + a_block_space_size, blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size, p_b_block_double + b_block_space_size,
p_c_thread); p_c_thread);
} }
...@@ -342,7 +344,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -342,7 +344,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
} }
} }
#endif #endif
...@@ -361,7 +363,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -361,7 +363,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// calculate origin of thread input tensor on global memory // calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
block_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global = const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row; m_block_data_on_global + c_thread_mtx_on_block.row;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment