Commit ff224fcd authored by Jing Zhang

add output padding

parent c0b9d8c2
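The diff below rounds the convolution output size (Ho x Wo) up to the next multiple of the block tile (HoPerBlock x WoPerBlock), pads the output descriptor by the remainder, and grows the input right-padding by the same amount scaled by the convolution stride. A minimal standalone sketch of that arithmetic follows; all concrete numbers (Ho, Wo, tile sizes, strides) are illustrative examples, not values taken from the repository's configs.

// Standalone sketch of the output-padding arithmetic introduced by this commit.
// The shape and tile numbers below are examples only.
#include <iostream>

int main()
{
    const int Ho = 538, Wo = 958;              // raw convolution output size (example)
    const int HoPerBlock = 8, WoPerBlock = 32; // block tile sizes (example)
    const int ConvStrideH = 1, ConvStrideW = 1;

    // Round Ho/Wo up to the next multiple of the block tile; the remainder is the pad.
    const int OutRightPadH = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock - Ho;
    const int OutRightPadW = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock - Wo;

    // Each extra output row/column advances the input window by one stride,
    // so the input right-padding grows by OutRightPad * ConvStride.
    const int InExtraRightPadH = OutRightPadH * ConvStrideH;
    const int InExtraRightPadW = OutRightPadW * ConvStrideW;

    const int Ho_new = Ho + OutRightPadH;
    const int Wo_new = Wo + OutRightPadW;

    std::cout << "OutRightPadH = " << OutRightPadH << ", OutRightPadW = " << OutRightPadW << '\n';
    std::cout << "extra input right pad H/W = " << InExtraRightPadH << " / " << InExtraRightPadW << '\n';

    // After padding, the tile-divisibility check in the driver passes.
    std::cout << "Ho_new % HoPerBlock = " << Ho_new % HoPerBlock
              << ", Wo_new % WoPerBlock = " << Wo_new % WoPerBlock << '\n';
    return 0;
}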
@@ -75,14 +75,30 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
         const auto ConvDilationH = conv_dilations[I0];
         const auto ConvDilationW = conv_dilations[I1];

+#if 0
         const auto InLeftPadH = in_left_pads[I0];
         const auto InLeftPadW = in_left_pads[I1];

         const auto InRightPadH = in_right_pads[I0];
         const auto InRightPadW = in_right_pads[I1];
+#else
+        const auto OutRightPadH = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock - Ho;
+        const auto OutRightPadW = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock - Wo;
+
+        const auto InLeftPadH = in_left_pads[I0];
+        const auto InLeftPadW = in_left_pads[I1];
+
+        const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
+        const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
+
+        std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
+                  << std::endl;
+        std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
+                  << std::endl;
+#endif

         // weight tensor
-        const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
+        const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
             make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -108,7 +124,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

-        const auto in_gemmk_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
+        const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
             in_n_c_y_ho_x_wo_global_desc,
             make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                        make_pass_through_transform(N),
@@ -118,7 +134,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

         // output tensor
-        const auto out_gemmm_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
+#if 0
+        const auto out_k_n_ho_wo_global_desc =
+            transform_dynamic_tensor_descriptor(
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
             make_tuple(make_merge_transform(make_tuple(K0, K1)),
                        make_pass_through_transform(N),
@@ -126,11 +144,26 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                        make_pass_through_transform(Wo)),
             make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+#else
+        const auto out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
+            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
+            make_tuple(make_merge_transform(make_tuple(K0, K1)),
+                       make_pass_through_transform(N),
+                       make_pad_transform(Ho, 0, OutRightPadH),
+                       make_pad_transform(Wo, 0, OutRightPadW)),
+            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+#endif

         const auto E = C * Y * X;

-        if(!(K % KPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0 &&
-             E % EPerBlock == 0))
+        const int Ho_new = out_k_n_ho_wo_global_desc.GetLength(I2);
+        const int Wo_new = out_k_n_ho_wo_global_desc.GetLength(I3);
+
+        std::cerr << "Ho_new = " << Ho_new << " Wo_new = " << Wo_new << std::endl;
+
+        if(!((K % KPerBlock) == 0 && (Ho_new % HoPerBlock) == 0 && (Wo_new % WoPerBlock) == 0 &&
+             (E % EPerBlock) == 0))
         {
             throw std::runtime_error("wrong! GEMM size no divisible");
         }
@@ -175,9 +208,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
             FloatAcc,
             FloatC,
             InMemoryDataOperation::Set,
-            decltype(wei_gemmk_gemmm_global_desc),
-            decltype(in_gemmk_n_ho_wo_global_desc),
-            decltype(out_gemmm_n_ho_wo_global_desc),
+            decltype(wei_e_k_global_desc),
+            decltype(in_e_n_ho_wo_global_desc),
+            decltype(out_k_n_ho_wo_global_desc),
             KPerBlock,
             HoPerBlock,
             WoPerBlock,
@@ -230,108 +263,104 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
         {
             if(has_main_k_block_loop && has_double_tail_k_block_loop)
             {
-                const auto kernel =
-                    run_gridwise_operation<gridwise_gemm,
-                                           decltype(wei_gemmk_gemmm_global_desc),
-                                           const FloatAB*,
-                                           decltype(in_gemmk_n_ho_wo_global_desc),
-                                           const FloatAB*,
-                                           decltype(out_gemmm_n_ho_wo_global_desc),
-                                           FloatC*,
-                                           integral_constant<bool, true>,
-                                           integral_constant<bool, true>>;
+                const auto kernel = run_gridwise_operation<gridwise_gemm,
+                                                           decltype(wei_e_k_global_desc),
+                                                           const FloatAB*,
+                                                           decltype(in_e_n_ho_wo_global_desc),
+                                                           const FloatAB*,
+                                                           decltype(out_k_n_ho_wo_global_desc),
+                                                           FloatC*,
+                                                           integral_constant<bool, true>,
+                                                           integral_constant<bool, true>>;

                 launch_kernel(kernel,
                               dim3(GridSize),
                               dim3(BlockSize),
                               0,
                               0,
-                              wei_gemmk_gemmm_global_desc,
+                              wei_e_k_global_desc,
                               p_wei_global,
-                              in_gemmk_n_ho_wo_global_desc,
+                              in_e_n_ho_wo_global_desc,
                               p_in_global,
-                              out_gemmm_n_ho_wo_global_desc,
+                              out_k_n_ho_wo_global_desc,
                               p_out_global,
                               integral_constant<bool, true>{},
                               integral_constant<bool, true>{});
             }
             else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
             {
-                const auto kernel =
-                    run_gridwise_operation<gridwise_gemm,
-                                           decltype(wei_gemmk_gemmm_global_desc),
-                                           const FloatAB*,
-                                           decltype(in_gemmk_n_ho_wo_global_desc),
-                                           const FloatAB*,
-                                           decltype(out_gemmm_n_ho_wo_global_desc),
-                                           FloatC*,
-                                           integral_constant<bool, true>,
-                                           integral_constant<bool, false>>;
+                const auto kernel = run_gridwise_operation<gridwise_gemm,
+                                                           decltype(wei_e_k_global_desc),
+                                                           const FloatAB*,
+                                                           decltype(in_e_n_ho_wo_global_desc),
+                                                           const FloatAB*,
+                                                           decltype(out_k_n_ho_wo_global_desc),
+                                                           FloatC*,
+                                                           integral_constant<bool, true>,
+                                                           integral_constant<bool, false>>;

                 launch_kernel(kernel,
                               dim3(GridSize),
                               dim3(BlockSize),
                               0,
                               0,
-                              wei_gemmk_gemmm_global_desc,
+                              wei_e_k_global_desc,
                               p_wei_global,
-                              in_gemmk_n_ho_wo_global_desc,
+                              in_e_n_ho_wo_global_desc,
                               p_in_global,
-                              out_gemmm_n_ho_wo_global_desc,
+                              out_k_n_ho_wo_global_desc,
                               p_out_global,
                               integral_constant<bool, true>{},
                               integral_constant<bool, false>{});
             }
             else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
             {
-                const auto kernel =
-                    run_gridwise_operation<gridwise_gemm,
-                                           decltype(wei_gemmk_gemmm_global_desc),
-                                           const FloatAB*,
-                                           decltype(in_gemmk_n_ho_wo_global_desc),
-                                           const FloatAB*,
-                                           decltype(out_gemmm_n_ho_wo_global_desc),
-                                           FloatC*,
-                                           integral_constant<bool, false>,
-                                           integral_constant<bool, true>>;
+                const auto kernel = run_gridwise_operation<gridwise_gemm,
+                                                           decltype(wei_e_k_global_desc),
+                                                           const FloatAB*,
+                                                           decltype(in_e_n_ho_wo_global_desc),
+                                                           const FloatAB*,
+                                                           decltype(out_k_n_ho_wo_global_desc),
+                                                           FloatC*,
+                                                           integral_constant<bool, false>,
+                                                           integral_constant<bool, true>>;

                 launch_kernel(kernel,
                               dim3(GridSize),
                               dim3(BlockSize),
                               0,
                               0,
-                              wei_gemmk_gemmm_global_desc,
+                              wei_e_k_global_desc,
                               p_wei_global,
-                              in_gemmk_n_ho_wo_global_desc,
+                              in_e_n_ho_wo_global_desc,
                               p_in_global,
-                              out_gemmm_n_ho_wo_global_desc,
+                              out_k_n_ho_wo_global_desc,
                               p_out_global,
                               integral_constant<bool, false>{},
                               integral_constant<bool, true>{});
             }
             else
             {
-                const auto kernel =
-                    run_gridwise_operation<gridwise_gemm,
-                                           decltype(wei_gemmk_gemmm_global_desc),
-                                           const FloatAB*,
-                                           decltype(in_gemmk_n_ho_wo_global_desc),
-                                           const FloatAB*,
-                                           decltype(out_gemmm_n_ho_wo_global_desc),
-                                           FloatC*,
-                                           integral_constant<bool, false>,
-                                           integral_constant<bool, false>>;
+                const auto kernel = run_gridwise_operation<gridwise_gemm,
+                                                           decltype(wei_e_k_global_desc),
+                                                           const FloatAB*,
+                                                           decltype(in_e_n_ho_wo_global_desc),
+                                                           const FloatAB*,
+                                                           decltype(out_k_n_ho_wo_global_desc),
+                                                           FloatC*,
+                                                           integral_constant<bool, false>,
+                                                           integral_constant<bool, false>>;

                 launch_kernel(kernel,
                               dim3(GridSize),
                               dim3(BlockSize),
                               0,
                               0,
-                              wei_gemmk_gemmm_global_desc,
+                              wei_e_k_global_desc,
                               p_wei_global,
-                              in_gemmk_n_ho_wo_global_desc,
+                              in_e_n_ho_wo_global_desc,
                               p_in_global,
-                              out_gemmm_n_ho_wo_global_desc,
+                              out_k_n_ho_wo_global_desc,
                               p_out_global,
                               integral_constant<bool, false>{},
                               integral_constant<bool, false>{});
...
@@ -82,10 +82,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
     const auto out_n_k0_ho_wo_k1_desc =
         make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));

     const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
     const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});

     const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
     const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
 #endif

     Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
@@ -111,6 +111,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
     in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
     wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());

+#if 1
     // cdata = 64, BlockSize = 64, 16x8x32x4
     constexpr index_t BlockSize = 64;
@@ -135,6 +136,31 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
     constexpr index_t CThreadTransferDstScalarPerVector_W = K1;

     static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
+#else
+    constexpr index_t BlockSize = 64;
+
+    constexpr index_t KPerBlock = 16;
+    constexpr index_t HoPerBlock = 8;
+    constexpr index_t WoPerBlock = 32;
+    constexpr index_t EPerBlock = 1;
+
+    constexpr index_t KPerThread = 16;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t EPerThread = EPerBlock;
+
+    using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
+    using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, 16>;
+
+    constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
+    constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
+
+    constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
+
+    constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
+
+    static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
+#endif

     constexpr auto conv_driver =
         DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
...
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
-#elif 1
+#elif 0
     constexpr index_t N = 1;
     constexpr index_t C = 16;
    constexpr index_t HI = 1080;
@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
-#elif 0
+#elif 1
     constexpr index_t N = 1;
     constexpr index_t C = 16;
     constexpr index_t HI = 540;
...