Commit c21521a1 authored by Jing Zhang's avatar Jing Zhang
Browse files

added fusion into non-outpad

parent 514b2d1c
...@@ -45,6 +45,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -45,6 +45,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
const FloatC* __restrict__ p_d_global,
FloatC* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -236,6 +237,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -236,6 +237,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_ho_wo_global_desc), decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -250,6 +252,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -250,6 +252,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global, p_out_global,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
...@@ -262,6 +265,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -262,6 +265,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_ho_wo_global_desc), decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -276,6 +280,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -276,6 +280,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global, p_out_global,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
...@@ -288,6 +293,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -288,6 +293,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_ho_wo_global_desc), decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -302,6 +308,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -302,6 +308,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global, p_out_global,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
...@@ -314,6 +321,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -314,6 +321,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_ho_wo_global_desc), decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -328,6 +336,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -328,6 +336,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global, p_out_global,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
......
...@@ -353,10 +353,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -353,10 +353,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
} }
#endif #endif
#if 1
FloatC p_d_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()]; FloatC p_d_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_d_thread); threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_d_thread);
#if 1
{ {
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
...@@ -388,12 +388,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -388,12 +388,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_d_thread, p_d_thread,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
} }
#endif
for(index_t i = 0; i < c_k_n_ho_wo_thread_desc.GetElementSpaceSize(); i++) for(index_t i = 0; i < c_k_n_ho_wo_thread_desc.GetElementSpaceSize(); i++)
{ {
p_d_thread[i] += p_c_thread[i]; p_d_thread[i] += p_c_thread[i];
} }
#endif
#if 1 #if 1
// output: register to global memory // output: register to global memory
......
...@@ -137,8 +137,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -137,8 +137,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock; constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>; using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, KPerBlock>; using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1; constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment