Commit 8e35a579 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 89123dd7
......@@ -131,7 +131,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
make_multi_index(KPerBlock, NPerBlock), max_lds_align);
// A matrix blockwise copy
auto a_block_copy =
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock>,
......@@ -160,7 +160,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
make_multi_index(0, 0));
// B matrix blockwise copy
auto b_block_copy =
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock>,
......@@ -219,7 +219,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{});
const auto block_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
......@@ -256,14 +256,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
#if 1
// LDS double buffer: preload data into LDS
{
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
}
#endif
#if 1
Float* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double;
......@@ -275,65 +276,66 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
k_block_data_begin += 2 * KPerBlock)
{
// even iteration
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads();
// LDS double buffer: load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
// LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
// odd iteration
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads();
// LDS double buffer: load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
}
#endif
#if 1
// LDS double buffer: tail
{
if constexpr(IsEvenNumberKBlockLoop) // if has 2 iteration left
{
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads();
// LDS double buffer: load last data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on 2nd-last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
block_gemm.Run(p_a_block_double + a_block_space_size,
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
......@@ -342,7 +344,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
__syncthreads();
// LDS double buffer: GEMM on last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
#endif
......@@ -361,7 +363,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
block_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment