Commit c21521a1 authored by Jing Zhang
Browse files

added fusion into non-outpad

parent 514b2d1c
......@@ -45,6 +45,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global,
const FloatC* __restrict__ p_d_global,
FloatC* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
......@@ -236,6 +237,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
......@@ -250,6 +252,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
......@@ -262,6 +265,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
......@@ -276,6 +280,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
......@@ -288,6 +293,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
......@@ -302,6 +308,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
......@@ -314,6 +321,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
......@@ -328,6 +336,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
......
......@@ -353,10 +353,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
}
#endif
#if 1
FloatC p_d_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_d_thread);
#if 1
{
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
......@@ -388,12 +388,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_d_thread,
c_k_n_ho_wo_global_tensor_iterator_hacks);
}
#endif
for(index_t i = 0; i < c_k_n_ho_wo_thread_desc.GetElementSpaceSize(); i++)
{
p_d_thread[i] += p_c_thread[i];
}
#endif
#if 1
// output: register to global memory
......
......@@ -137,8 +137,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, KPerBlock>;
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment