Commit dadd5a91 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed e1

parent 8f520532
......@@ -290,6 +290,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
a_e0_e1_k_block_desc,
make_multi_index(0, 0, 0));
constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0);
constexpr auto b_e0_e1_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
I1, Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
......@@ -336,6 +338,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
true>
b_thread_even_buf, b_thread_odd_buf;
const auto E0 = b_e0_e1_n_ho_wo_global_desc.GetLength(I0);
index_t e0_block_data_begin = 0;
do
{
// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(
......@@ -355,7 +363,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
if constexpr(HasMainKBlockLoop)
{
index_t e_block_data_begin = 0;
index_t e1_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
......@@ -393,9 +401,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0));
e_block_data_begin += 2 * EPerBlock;
e1_block_data_begin += 2 * EPerBlock;
} while(e_block_data_begin < E1 - 2 * EPerBlock);
} while(e1_block_data_begin < E1 - 2 * EPerBlock);
}
// LDS double buffer: tail
......@@ -425,6 +433,18 @@ struct GridwiseGemmDlops_km_kn_mn_v3
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
a_blockwise_copy.MoveSrcSliceWindow(
a_e0_e1_k_global_desc, a_block_slice_copy_step, AGlobalMoveSliceWindowStepHacks{});
blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - EPerBlock), 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc,
b_thread_slice_copy_step);
e0_block_data_begin += 1;
} while(e0_block_data_begin < E0);
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
......
......@@ -116,12 +116,12 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
using ABlockTransferThreadSliceLengths_E0_E1_K = Sequence<1, 4, 1>;
using ABlockTransferThreadClusterLengths_E0_E1_K = Sequence<1, 4, 16>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_E = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_E = 4;
constexpr index_t CThreadTransferDstScalarPerVector_K = 1;
constexpr index_t CThreadTransferDstScalarPerVector_K = 4;
#endif
constexpr auto conv_driver =
......
......@@ -191,7 +191,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto b_e0_e1_n_ho_wo_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
......
......@@ -52,7 +52,7 @@ REPEAT=$6
#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 3 3 8 8 1 1 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 1 1 8 8 1 1 1 1 0 0 0 0
./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 3 3 8 8 1 1 1 1 0 0 0 0
################################################ layout algo verify init log repeat M___ N___ K___
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment