Commit 91ef99a7 authored by root's avatar root
Browse files

double buffer b with bug

parent 88d51698
...@@ -147,7 +147,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -147,7 +147,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
// loop over k // loop over k
for(index_t cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop) for(index_t cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop)
{ {
#if 1 #if 0
a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(cyx_begin, 0)) + a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(cyx_begin, 0)) +
mMyThreadOffsetA, mMyThreadOffsetA,
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, 0))); p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, 0)));
......
...@@ -219,9 +219,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -219,9 +219,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_cyx_k_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_cyx_k_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size = math::integer_least_multiple(
b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align);
Float* p_a_block_double = p_shared_block; Float* p_a_block_double = p_shared_block;
// register allocation for output // register allocation for output
...@@ -235,8 +232,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -235,8 +232,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// zero out threadwise output // zero out threadwise output
// threadwise_matrix_set_zero_v2(c_k_n_h_w_thread_desc, p_c_thread); // threadwise_matrix_set_zero_v2(c_k_n_h_w_thread_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(CYXPerBlock, 0); constexpr auto a_block_slice_copy_step = make_multi_index(CYXPerBlock, 0);
// constexpr auto b_block_slice_copy_step = make_multi_index(CYXPerBlock, 0, 0, 0); constexpr auto b_thread_slice_copy_step = make_multi_index(CYXPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{}; constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
...@@ -249,8 +246,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -249,8 +246,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
constexpr auto b_cyx_n_h_w_global_move_slice_window_iterator_hack = constexpr auto b_cyx_n_h_w_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{}; BGlobalMoveSliceWindowIteratorHacks{};
Float p_b_thread[b_cyx_n_h_w_thread_desc.GetElementSpaceSize()]; constexpr auto b_thread_space_size = b_cyx_n_h_w_thread_desc.GetElementSpaceSize();
Float p_b_thread[b_thread_space_size * 2];
Float* p_b_thread_double = p_b_thread;
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
...@@ -260,27 +259,32 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -260,27 +259,32 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
p_b_global, p_b_global,
b_cyx_n_h_w_thread_desc, b_cyx_n_h_w_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_b_thread, p_b_thread_double,
b_cyx_n_h_w_global_iterator_hacks); b_cyx_n_h_w_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double); a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double);
#if 0
__syncthreads(); __syncthreads();
//blockwise_gemm.Run(p_a_block_double, p_b_thread, p_c_thread); //blockwise_gemm.Run(p_a_block_double, p_b_thread_double, p_c_thread);
index_t sum = 0; index_t sum = 0;
for(index_t i = 0; i < b_cyx_n_h_w_thread_desc.GetElementSpaceSize(); i++) for(index_t i = 0; i < b_cyx_n_h_w_thread_desc.GetElementSpaceSize(); i++)
sum += p_b_thread[i]; sum += p_b_thread[i];
p_c_thread[0] = get_thread_local_1d_id() * 10000 + sum; p_c_thread[0] = get_thread_local_1d_id() * 10000 + sum;
#endif
} }
#if 0 #if 1
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
Float* p_a_block_even = p_a_block_double; Float* p_a_block_even = p_a_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size; Float* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_thread_even = p_b_thread_double;
Float* p_b_thread_odd = p_b_thread_double + b_thread_space_size;
index_t b_block_data_begin = 0; index_t b_block_data_begin = 0;
...@@ -293,14 +297,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -293,14 +297,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_block_slice_copy_step, a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack); a_k_m_global_move_slice_window_iterator_hack);
b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
b_thread_slice_copy_step);
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
p_b_global,
b_cyx_n_h_w_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_odd,
b_cyx_n_h_w_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread); blockwise_gemm.Run(p_a_block_even, p_b_thread_even, p_c_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_odd); a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_odd);
...@@ -309,14 +323,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -309,14 +323,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc, a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack); a_k_m_global_move_slice_window_iterator_hack);
b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
b_thread_slice_copy_step);
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
p_b_global,
b_cyx_n_h_w_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_even,
b_cyx_n_h_w_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread); blockwise_gemm.Run(p_a_block_odd, p_b_thread_odd, p_c_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_even); a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_even);
...@@ -332,13 +356,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -332,13 +356,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_block_slice_copy_step, a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack); a_k_m_global_move_slice_window_iterator_hack);
b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
b_thread_slice_copy_step);
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_blockwise_copy.RunRead(a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
p_b_global,
b_cyx_n_h_w_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_double + b_thread_space_size,
b_cyx_n_h_w_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(p_a_block_double, p_b_thread_double, p_c_thread);
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double + a_block_space_size); a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double + a_block_space_size);
...@@ -347,7 +381,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -347,7 +381,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space_size, blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size, p_b_thread_double + b_thread_space_size,
p_c_thread); p_c_thread);
} }
else // if has 1 iteration left else // if has 1 iteration left
...@@ -355,7 +389,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -355,7 +389,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(p_a_block_double, p_b_thread_double, p_c_thread);
} }
#endif #endif
......
...@@ -73,20 +73,20 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -73,20 +73,20 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HPerBlock = 8; constexpr index_t HPerBlock = 8;
constexpr index_t WPerBlock = 8; constexpr index_t WPerBlock = 8;
constexpr index_t CYXPerBlock = 4 * 3 * 3; constexpr index_t CYXPerBlock = 4;
constexpr index_t KPerThread = 16; constexpr index_t KPerThread = 16;
constexpr index_t HPerThread = 1; constexpr index_t HPerThread = 1;
constexpr index_t WPerThread = 1; constexpr index_t WPerThread = 1;
constexpr index_t CYXPerThread = 4 * 3 * 3; constexpr index_t CYXPerThread = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<9, 1>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<36, 1>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment