Commit dadd5a91 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed e1

parent 8f520532
...@@ -290,6 +290,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -290,6 +290,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
a_e0_e1_k_block_desc, a_e0_e1_k_block_desc,
make_multi_index(0, 0, 0)); make_multi_index(0, 0, 0));
constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0);
constexpr auto b_e0_e1_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( constexpr auto b_e0_e1_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
I1, Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); I1, Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
...@@ -336,48 +338,77 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -336,48 +338,77 @@ struct GridwiseGemmDlops_km_kn_mn_v3
true> true>
b_thread_even_buf, b_thread_odd_buf; b_thread_even_buf, b_thread_odd_buf;
// LDS double buffer: preload data const auto E0 = b_e0_e1_n_ho_wo_global_desc.GetLength(I0);
{
a_blockwise_copy.RunRead(
a_e0_e1_k_global_desc, a_global_buf, a_e0_e1_k_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf,
b_e0_e1_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_ho_wo_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k_block_desc, a_block_buf);
}
__syncthreads(); index_t e0_block_data_begin = 0;
if constexpr(HasMainKBlockLoop) do
{ {
index_t e_block_data_begin = 0; // LDS double buffer: preload data
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{ {
// even iteration a_blockwise_copy.RunRead(
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc, a_e0_e1_k_global_desc, a_global_buf, a_e0_e1_k_global_step_hacks);
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_thread_desc, b_e0_e1_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
b_thread_odd_buf, b_thread_even_buf,
b_e0_e1_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data a_blockwise_copy.RunWrite(a_e0_e1_k_block_desc, a_block_buf);
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window }
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0)); __syncthreads();
if constexpr(HasMainKBlockLoop)
{
index_t e1_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf,
b_e0_e1_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf,
b_e0_e1_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0));
e1_block_data_begin += 2 * EPerBlock;
} while(e1_block_data_begin < E1 - 2 * EPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc,
b_thread_slice_copy_step); b_thread_slice_copy_step);
...@@ -385,45 +416,34 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -385,45 +416,34 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_thread_desc, b_e0_e1_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
b_thread_even_buf, b_thread_odd_buf,
b_e0_e1_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0)); blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0));
e_block_data_begin += 2 * EPerBlock; // LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
a_blockwise_copy.MoveSrcSliceWindow(
a_e0_e1_k_global_desc, a_block_slice_copy_step, AGlobalMoveSliceWindowStepHacks{});
} while(e_block_data_begin < E1 - 2 * EPerBlock); blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - EPerBlock), 0));
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc, e0_block_data_begin += 1;
b_global_buf,
b_e0_e1_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data } while(e0_block_data_begin < E0);
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0));
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
// output: register to global memory // output: register to global memory
{ {
......
...@@ -116,12 +116,12 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -116,12 +116,12 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
using ABlockTransferThreadSliceLengths_E0_E1_K = Sequence<1, 4, 1>; using ABlockTransferThreadSliceLengths_E0_E1_K = Sequence<1, 4, 1>;
using ABlockTransferThreadClusterLengths_E0_E1_K = Sequence<1, 4, 16>; using ABlockTransferThreadClusterLengths_E0_E1_K = Sequence<1, 4, 16>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; constexpr index_t ABlockTransferSrcScalarPerVector_E = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1; constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_E = 1; constexpr index_t BThreadTransferSrcScalarPerVector_E = 4;
constexpr index_t CThreadTransferDstScalarPerVector_K = 1; constexpr index_t CThreadTransferDstScalarPerVector_K = 4;
#endif #endif
constexpr auto conv_driver = constexpr auto conv_driver =
......
...@@ -191,7 +191,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -191,7 +191,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto b_e0_e1_n_ho_wo_global_move_slice_window_step_hack = constexpr auto b_e0_e1_n_ho_wo_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format // hack for NKHW format
......
...@@ -52,7 +52,7 @@ REPEAT=$6 ...@@ -52,7 +52,7 @@ REPEAT=$6
#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 #./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 3 3 8 8 1 1 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 3 3 8 8 1 1 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 1 1 8 8 1 1 1 1 0 0 0 0 ./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 3 3 8 8 1 1 1 1 0 0 0 0
################################################ layout algo verify init log repeat M___ N___ K___ ################################################ layout algo verify init log repeat M___ N___ K___
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 #./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment