Commit 0c883faa authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed outputpad

parent 351c227a
......@@ -75,8 +75,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto OutRightPadH = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock - Ho;
const auto OutRightPadW = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock - Wo;
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
const auto OutRightPadH = Hop - Ho;
const auto OutRightPadW = Wop - Wo;
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
......@@ -111,8 +114,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
make_tuple(
make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......@@ -120,13 +123,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N),
make_pass_through_transform(Ho),
make_pass_through_transform(Wo)),
make_pass_through_transform(Hop),
make_pass_through_transform(Wop)),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor
const auto out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
......@@ -137,12 +140,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const auto E = C * Y * X;
const int Ho_new = out_k_n_ho_wo_global_desc.GetLength(I2);
const int Wo_new = out_k_n_ho_wo_global_desc.GetLength(I3);
std::cerr << "Ho_new = " << Ho_new << " Wo_new = " << Wo_new << std::endl;
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
if(!((K % KPerBlock) == 0 && (Ho_new % HoPerBlock) == 0 && (Wo_new % WoPerBlock) == 0 &&
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
(E % EPerBlock) == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
......@@ -190,7 +190,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
InMemoryDataOperation::Set,
decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_ho_wo_global_desc),
decltype(out_k_n_hop_wop_global_desc),
KPerBlock,
HoPerBlock,
WoPerBlock,
......@@ -221,7 +221,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
......@@ -243,12 +243,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
......@@ -262,19 +263,20 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
......@@ -288,19 +290,20 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
......@@ -314,19 +317,20 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
......@@ -340,7 +344,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
......
......@@ -36,11 +36,10 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
//constexpr index_t HI = 540;
constexpr index_t HI = 544;
constexpr index_t HI = 540;
constexpr index_t WI = 960;
constexpr index_t K = 16;
constexpr index_t Y = 1;
......@@ -107,7 +106,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
#elif 1
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 540;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment