Commit 0c883faa authored by Jing Zhang
Browse files

fixed outputpad

parent 351c227a
...@@ -75,8 +75,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -75,8 +75,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const auto ConvDilationH = conv_dilations[I0]; const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1]; const auto ConvDilationW = conv_dilations[I1];
const auto OutRightPadH = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock - Ho; const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
const auto OutRightPadW = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock - Wo; const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
const auto OutRightPadH = Hop - Ho;
const auto OutRightPadW = Wop - Wo;
const auto InLeftPadH = in_left_pads[I0]; const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1]; const auto InLeftPadW = in_left_pads[I1];
...@@ -111,8 +114,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -111,8 +114,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N),
make_pass_through_transform(C), make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
...@@ -120,13 +123,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -120,13 +123,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
in_n_c_y_ho_x_wo_global_desc, in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)), make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N), make_pass_through_transform(N),
make_pass_through_transform(Ho), make_pass_through_transform(Hop),
make_pass_through_transform(Wo)), make_pass_through_transform(Wop)),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor // output tensor
const auto out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)), make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)), make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N), make_pass_through_transform(N),
...@@ -137,12 +140,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -137,12 +140,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const auto E = C * Y * X; const auto E = C * Y * X;
const int Ho_new = out_k_n_ho_wo_global_desc.GetLength(I2); std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
const int Wo_new = out_k_n_ho_wo_global_desc.GetLength(I3);
std::cerr << "Ho_new = " << Ho_new << " Wo_new = " << Wo_new << std::endl;
if(!((K % KPerBlock) == 0 && (Ho_new % HoPerBlock) == 0 && (Wo_new % WoPerBlock) == 0 && if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
(E % EPerBlock) == 0)) (E % EPerBlock) == 0))
{ {
throw std::runtime_error("wrong! GEMM size no divisible"); throw std::runtime_error("wrong! GEMM size no divisible");
...@@ -190,7 +190,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -190,7 +190,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_ho_wo_global_desc), decltype(out_k_n_hop_wop_global_desc),
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
WoPerBlock, WoPerBlock,
...@@ -221,7 +221,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -221,7 +221,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
decltype(a_k_m_global_move_slice_window_iterator_hack), decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>; decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N; const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1; const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
...@@ -243,15 +243,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -243,15 +243,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
{ {
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel =
decltype(wei_e_k_global_desc), run_gridwise_operation<gridwise_gemm,
const FloatAB*, decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc), const FloatAB*,
const FloatAB*, decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_ho_wo_global_desc), const FloatAB*,
FloatC*, decltype(out_k_n_hop_wop_global_desc),
integral_constant<bool, true>, FloatC*,
integral_constant<bool, true>>; integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel, launch_kernel(kernel,
dim3(GridSize), dim3(GridSize),
...@@ -262,22 +263,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -262,22 +263,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel =
decltype(wei_e_k_global_desc), run_gridwise_operation<gridwise_gemm,
const FloatAB*, decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc), const FloatAB*,
const FloatAB*, decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_ho_wo_global_desc), const FloatAB*,
FloatC*, decltype(out_k_n_hop_wop_global_desc),
integral_constant<bool, true>, FloatC*,
integral_constant<bool, false>>; integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel, launch_kernel(kernel,
dim3(GridSize), dim3(GridSize),
...@@ -288,22 +290,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -288,22 +290,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel =
decltype(wei_e_k_global_desc), run_gridwise_operation<gridwise_gemm,
const FloatAB*, decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc), const FloatAB*,
const FloatAB*, decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_ho_wo_global_desc), const FloatAB*,
FloatC*, decltype(out_k_n_hop_wop_global_desc),
integral_constant<bool, false>, FloatC*,
integral_constant<bool, true>>; integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel, launch_kernel(kernel,
dim3(GridSize), dim3(GridSize),
...@@ -314,22 +317,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -314,22 +317,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
} }
else else
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel =
decltype(wei_e_k_global_desc), run_gridwise_operation<gridwise_gemm,
const FloatAB*, decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc), const FloatAB*,
const FloatAB*, decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_ho_wo_global_desc), const FloatAB*,
FloatC*, decltype(out_k_n_hop_wop_global_desc),
integral_constant<bool, false>, FloatC*,
integral_constant<bool, false>>; integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel, launch_kernel(kernel,
dim3(GridSize), dim3(GridSize),
...@@ -340,7 +344,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -340,7 +344,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
......
...@@ -36,11 +36,10 @@ int main(int argc, char* argv[]) ...@@ -36,11 +36,10 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
//constexpr index_t HI = 540; constexpr index_t HI = 540;
constexpr index_t HI = 544;
constexpr index_t WI = 960; constexpr index_t WI = 960;
constexpr index_t K = 16; constexpr index_t K = 16;
constexpr index_t Y = 1; constexpr index_t Y = 1;
...@@ -107,7 +106,7 @@ int main(int argc, char* argv[]) ...@@ -107,7 +106,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 1
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 540; constexpr index_t HI = 540;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment