Commit ff224fcd authored by Jing Zhang's avatar Jing Zhang
Browse files

add output padding

parent c0b9d8c2
......@@ -75,14 +75,30 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
#if 0
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
#else
const auto OutRightPadH = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock - Ho;
const auto OutRightPadW = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock - Wo;
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
<< std::endl;
std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
<< std::endl;
#endif
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
......@@ -108,7 +124,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N),
......@@ -118,7 +134,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor
const auto out_gemmm_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
#if 0
const auto out_k_n_ho_wo_global_desc =
transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
......@@ -126,11 +144,26 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
make_pass_through_transform(Wo)),
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
#else
const auto out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_pad_transform(Ho, 0, OutRightPadH),
make_pad_transform(Wo, 0, OutRightPadW)),
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
#endif
const auto E = C * Y * X;
if(!(K % KPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0 &&
E % EPerBlock == 0))
const int Ho_new = out_k_n_ho_wo_global_desc.GetLength(I2);
const int Wo_new = out_k_n_ho_wo_global_desc.GetLength(I3);
std::cerr << "Ho_new = " << Ho_new << " Wo_new = " << Wo_new << std::endl;
if(!((K % KPerBlock) == 0 && (Ho_new % HoPerBlock) == 0 && (Wo_new % WoPerBlock) == 0 &&
(E % EPerBlock) == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
......@@ -175,9 +208,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_n_ho_wo_global_desc),
decltype(out_gemmm_n_ho_wo_global_desc),
decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_ho_wo_global_desc),
KPerBlock,
HoPerBlock,
WoPerBlock,
......@@ -230,108 +263,104 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
const auto kernel = run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
wei_e_k_global_desc,
p_wei_global,
in_gemmk_n_ho_wo_global_desc,
in_e_n_ho_wo_global_desc,
p_in_global,
out_gemmm_n_ho_wo_global_desc,
out_k_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
const auto kernel = run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
wei_e_k_global_desc,
p_wei_global,
in_gemmk_n_ho_wo_global_desc,
in_e_n_ho_wo_global_desc,
p_in_global,
out_gemmm_n_ho_wo_global_desc,
out_k_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
const auto kernel = run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
wei_e_k_global_desc,
p_wei_global,
in_gemmk_n_ho_wo_global_desc,
in_e_n_ho_wo_global_desc,
p_in_global,
out_gemmm_n_ho_wo_global_desc,
out_k_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
const auto kernel = run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
wei_e_k_global_desc,
p_wei_global,
in_gemmk_n_ho_wo_global_desc,
in_e_n_ho_wo_global_desc,
p_in_global,
out_gemmm_n_ho_wo_global_desc,
out_k_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
......
......@@ -82,10 +82,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif
Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
......@@ -111,6 +111,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
#if 1
// cdata = 64, BlockSize = 64, 16x8x32x4
constexpr index_t BlockSize = 64;
......@@ -135,6 +136,31 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
#else
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32;
constexpr index_t EPerBlock = 1;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, 16>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
#endif
constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
......
......@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 1080;
......@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
#elif 1
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 540;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment