Commit db4afa69 authored by Jing Zhang
Browse files

demo of upsampling

parent 3dd0cc31
...@@ -168,7 +168,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -168,7 +168,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}));
#if 1 #if 0
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3< using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
BlockSize, BlockSize,
......
...@@ -31,6 +31,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -31,6 +31,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
{ {
template <typename... Wei, template <typename... Wei,
typename... In, typename... In,
typename... Add,
typename... Out, typename... Out,
typename ConvStrides, typename ConvStrides,
typename ConvDilations, typename ConvDilations,
...@@ -38,6 +39,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -38,6 +39,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
typename InRightPads> typename InRightPads>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, __host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc, const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Add...>& add_n_k0_2ho_2wo_k1_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc, const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
...@@ -82,6 +84,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -82,6 +84,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const auto OutRightPadH = Hop - Ho; const auto OutRightPadH = Hop - Ho;
const auto OutRightPadW = Wop - Wo; const auto OutRightPadW = Wop - Wo;
const auto AddRightPadH = 2 * OutRightPadH;
const auto AddRightPadW = 2 * OutRightPadW;
const auto InLeftPadH = in_left_pads[I0]; const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1]; const auto InLeftPadW = in_left_pads[I1];
...@@ -92,6 +97,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -92,6 +97,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
<< std::endl; << std::endl;
std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
<< std::endl; << std::endl;
std::cerr << "AddRightPadH = " << AddRightPadH << " AddRightPadW = " << AddRightPadW
<< std::endl;
// weight tensor // weight tensor
const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor( const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
...@@ -139,6 +146,18 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -139,6 +146,18 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// add tensor
const auto add_k_n_2hop_2wop_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, 2 * Ho, 2 * Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_pad_transform(2 * Ho, 0, AddRightPadH),
make_pad_transform(2 * Wo, 0, AddRightPadW)),
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto E = C * Y * X; const auto E = C * Y * X;
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
...@@ -190,6 +209,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -190,6 +209,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
decltype(add_k_n_2hop_2wop_global_desc),
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
...@@ -249,8 +269,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -249,8 +269,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_hop_wop_global_desc), decltype(add_k_n_2hop_2wop_global_desc),
const FloatC*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -264,8 +285,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -264,8 +285,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_hop_wop_global_desc, add_k_n_2hop_2wop_global_desc,
p_d_global, p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
...@@ -278,8 +300,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -278,8 +300,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_hop_wop_global_desc), decltype(add_k_n_2hop_2wop_global_desc),
const FloatC*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -293,8 +316,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -293,8 +316,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_hop_wop_global_desc, add_k_n_2hop_2wop_global_desc,
p_d_global, p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
...@@ -307,8 +331,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -307,8 +331,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_hop_wop_global_desc), decltype(add_k_n_2hop_2wop_global_desc),
const FloatC*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -322,8 +347,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -322,8 +347,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_hop_wop_global_desc, add_k_n_2hop_2wop_global_desc,
p_d_global, p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
...@@ -336,8 +362,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -336,8 +362,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_hop_wop_global_desc), decltype(add_k_n_2hop_2wop_global_desc),
const FloatC*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -351,8 +378,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -351,8 +378,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_hop_wop_global_desc, add_k_n_2hop_2wop_global_desc,
p_d_global, p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
......
...@@ -18,6 +18,7 @@ template <index_t BlockSize, ...@@ -18,6 +18,7 @@ template <index_t BlockSize,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc, typename AGlobalDesc,
typename BGlobalDesc, typename BGlobalDesc,
typename DGlobalDesc,
typename CGlobalDesc, typename CGlobalDesc,
index_t KPerBlock, index_t KPerBlock,
index_t HoPerBlock, index_t HoPerBlock,
...@@ -73,8 +74,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -73,8 +74,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const DGlobalDesc& d_k_n_2ho_2wo_global_desc,
const FloatC* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
...@@ -146,6 +148,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -146,6 +148,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
constexpr auto d_k_n_2ho_2wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<2 * HoPerThread>{}, Number<2 * WoPerThread>{}));
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize, BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
decltype(a_e_k_block_desc), decltype(a_e_k_block_desc),
...@@ -354,22 +360,27 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -354,22 +360,27 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
#endif #endif
#if 1 #if 1
FloatC p_d_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()]; FloatC p_d_thread[d_k_n_2ho_2wo_thread_desc.GetElementSpaceSize()];
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_d_thread); threadwise_matrix_set_zero_v3(d_k_n_2ho_2wo_thread_desc, p_d_thread);
const index_t ho2_thread_data_on_global =
ho_block_data_on_global + ho_thread_id * HoPerThread * 2;
const index_t wo2_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread * 2;
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
{ {
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
ThreadwiseDynamicTensorSliceTransfer_v2< ThreadwiseDynamicTensorSliceTransfer_v2<
FloatC, FloatC,
FloatC, FloatC,
decltype(c_k_n_ho_wo_global_desc), decltype(d_k_n_2ho_2wo_global_desc),
decltype(c_k_n_ho_wo_thread_desc), decltype(d_k_n_2ho_2wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>, Sequence<KPerThread, 1, 2 * HoPerThread, 2 * WoPerThread>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
...@@ -378,20 +389,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -378,20 +389,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
1, 1,
true>( true>(
c_k_n_ho_wo_global_desc, d_k_n_2ho_2wo_global_desc,
make_multi_index( make_multi_index(
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) k_thread_data_on_global, 0, ho2_thread_data_on_global, wo2_thread_data_on_global))
.Run(c_k_n_ho_wo_global_desc, .Run(d_k_n_2ho_2wo_global_desc,
p_d_global, p_d_global,
c_k_n_ho_wo_thread_desc, d_k_n_2ho_2wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_d_thread, p_d_thread,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
} }
for(index_t i = 0; i < c_k_n_ho_wo_thread_desc.GetElementSpaceSize(); i++) for(index_t i = 0; i < d_k_n_2ho_2wo_thread_desc.GetElementSpaceSize(); i++)
{ {
p_d_thread[i] += p_c_thread[i]; p_d_thread[i] += p_c_thread[i / 2];
} }
#endif #endif
...@@ -401,15 +412,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -401,15 +412,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC, FloatC,
FloatC, FloatC,
decltype(c_k_n_ho_wo_thread_desc), decltype(d_k_n_2ho_2wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc), decltype(d_k_n_2ho_2wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>, Sequence<KPerThread, 1, 2 * HoPerThread, 2 * WoPerThread>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
...@@ -418,13 +426,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -418,13 +426,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>( true>(
c_k_n_ho_wo_global_desc, d_k_n_2ho_2wo_global_desc,
make_multi_index( make_multi_index(
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) k_thread_data_on_global, 0, ho2_thread_data_on_global, wo2_thread_data_on_global))
.Run(c_k_n_ho_wo_thread_desc, .Run(d_k_n_2ho_2wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_d_thread, p_d_thread,
c_k_n_ho_wo_global_desc, d_k_n_2ho_2wo_global_desc,
p_c_global, p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
} }
...@@ -437,8 +445,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -437,8 +445,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const DGlobalDesc& d_k_n_2ho_2wo_global_desc,
const FloatC* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
...@@ -451,8 +460,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -451,8 +460,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global, p_a_global,
b_e_n_ho_wo_global_desc, b_e_n_ho_wo_global_desc,
p_b_global, p_b_global,
c_k_n_ho_wo_global_desc, d_k_n_2ho_2wo_global_desc,
p_d_global, p_d_global,
c_k_n_ho_wo_global_desc,
p_c_global, p_c_global,
p_shared_block, p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
...@@ -465,8 +475,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -465,8 +475,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc, const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc, const DGlobalDesc& d_k_n_2ho_2wo_global_desc,
const FloatC* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
...@@ -479,8 +490,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -479,8 +490,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global, p_a_global,
b_e_n_ho_wo_global_desc, b_e_n_ho_wo_global_desc,
p_b_global, p_b_global,
c_k_n_ho_wo_global_desc, d_k_n_2ho_2wo_global_desc,
p_d_global, p_d_global,
c_k_n_ho_wo_global_desc,
p_c_global, p_c_global,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
...@@ -492,8 +504,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -492,8 +504,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc, const void* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const void* p_c_k_n_ho_wo_global_desc, const DGlobalDesc& d_k_n_2ho_2wo_global_desc,
const FloatC* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const void* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
...@@ -508,8 +521,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -508,8 +521,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global, p_a_global,
b_e_n_ho_wo_global_desc, b_e_n_ho_wo_global_desc,
p_b_global, p_b_global,
c_k_n_ho_wo_global_desc, d_k_n_2ho_2wo_global_desc,
p_d_global, p_d_global,
c_k_n_ho_wo_global_desc,
p_c_global, p_c_global,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
......
...@@ -10,6 +10,7 @@ template <class TInWei, ...@@ -10,6 +10,7 @@ template <class TInWei,
class TOut, class TOut,
class InDesc, class InDesc,
class WeiDesc, class WeiDesc,
class AddDesc,
class OutDesc, class OutDesc,
class ConvStrides, class ConvStrides,
class ConvDilations, class ConvDilations,
...@@ -20,8 +21,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -20,8 +21,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
const Tensor<TInWei>& in_n_c_hi_wi, const Tensor<TInWei>& in_n_c_hi_wi,
WeiDesc, WeiDesc,
const Tensor<TInWei>& wei_k_c_y_x, const Tensor<TInWei>& wei_k_c_y_x,
AddDesc,
const Tensor<TOut>& add_n_k_2ho_2wo,
OutDesc, OutDesc,
Tensor<TOut>& add_n_k_ho_wo,
Tensor<TOut>& out_n_k_ho_wo, Tensor<TOut>& out_n_k_ho_wo,
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
...@@ -36,8 +38,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -36,8 +38,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem add_n_k_ho_wo_device_buf(sizeof(TOut) * add_n_k_ho_wo.mDesc.GetElementSpace()); DeviceMem add_n_k_2ho_2wo_device_buf(sizeof(TOut) * add_n_k_2ho_2wo.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * add_n_k_2ho_2wo.mDesc.GetElementSpace());
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -84,6 +86,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -84,6 +86,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
const auto out_n_k0_ho_wo_k1_desc = const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
const auto add_n_k0_2ho_2wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, 2 * Ho, 2 * Wo, K1));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
...@@ -95,10 +99,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -95,10 +99,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{}))); make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{})));
Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor( Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{}))); make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{})));
Tensor<TOut> add_n_k0_ho_wo_k1(make_HostTensorDescriptor( Tensor<TOut> add_n_k0_2ho_2wo_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, Ho, Wo, K1>{}))); make_native_tensor_descriptor_packed(Sequence<N, K0, 2 * Ho, 2 * Wo, K1>{})));
Tensor<TOut> out_n_k0_ho_wo_k1(make_HostTensorDescriptor( Tensor<TOut> out_n_k0_ho_wo_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, Ho, Wo, K1>{}))); make_native_tensor_descriptor_packed(Sequence<N, K0, 2 * Ho, 2 * Wo, K1>{})));
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) { auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) = in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
...@@ -111,8 +115,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -111,8 +115,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
}; };
auto f_nkhw_to_nk0hwk1 = [&](auto n, auto k, auto ho, auto wo) { auto f_nkhw_to_nk0hwk1 = [&](auto n, auto k, auto ho, auto wo) {
add_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize) = add_n_k0_2ho_2wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize) =
add_n_k_ho_wo(n, k, ho, wo); add_n_k_2ho_2wo(n, k, ho, wo);
}; };
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)(); make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
...@@ -121,7 +125,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -121,7 +125,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
add_n_k_ho_wo_device_buf.ToDevice(add_n_k0_ho_wo_k1.mData.data()); add_n_k_2ho_2wo_device_buf.ToDevice(add_n_k0_2ho_2wo_k1.mData.data());
#if 1 #if 1
// cdata = 64, BlockSize = 64, 16x8x32x4 // cdata = 64, BlockSize = 64, 16x8x32x4
...@@ -176,20 +180,32 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -176,20 +180,32 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr auto conv_driver = constexpr auto conv_driver =
#if 0 #if 0
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad< DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
#else #else
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad< DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
#endif #endif
BlockSize, <BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type, TAcc, TOut, KPerBlock, typename vector_type<TInWei, InWeiVectorSize>::type,
HoPerBlock, WoPerBlock, EPerBlock, KPerThread, HoPerThread, WoPerThread, TAcc,
EPerThread, ABlockTransferThreadSliceLengths_E_K, TOut,
ABlockTransferThreadClusterLengths_E_K, ABlockTransferSrcScalarPerVector_E, KPerBlock,
ABlockTransferDstScalarPerVector_K, BThreadTransferSrcScalarPerVector_W, HoPerBlock,
CThreadTransferDstScalarPerVector_W > {}; WoPerBlock,
EPerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K,
BThreadTransferSrcScalarPerVector_W,
CThreadTransferDstScalarPerVector_W>{};
conv_driver.Run(wei_k_c0_y_x_desc, conv_driver.Run(wei_k_c0_y_x_desc,
in_n_c0_hi_wi_desc, in_n_c0_hi_wi_desc,
add_n_k0_2ho_2wo_k1_desc,
out_n_k0_ho_wo_k1_desc, out_n_k0_ho_wo_k1_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
...@@ -199,15 +215,18 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -199,15 +215,18 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf.GetDeviceBuffer()), wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()), in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(add_n_k_ho_wo_device_buf.GetDeviceBuffer()), static_cast<TOut*>(add_n_k_2ho_2wo_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer())); static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
#if 0
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) { auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_n_k_ho_wo(n, k, ho, wo) = out_n_k_ho_wo(n, k, ho, wo) =
out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize); out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
}; };
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)(); make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
#endif
} }
...@@ -64,7 +64,7 @@ int main(int argc, char* argv[]) ...@@ -64,7 +64,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
...@@ -92,7 +92,7 @@ int main(int argc, char* argv[]) ...@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 1
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 540; constexpr index_t HI = 540;
...@@ -622,9 +622,16 @@ int main(int argc, char* argv[]) ...@@ -622,9 +622,16 @@ int main(int argc, char* argv[])
auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor( auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{}); in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});
constexpr auto Ho = out_nkhw_desc.GetLength(Number<2>{});
constexpr auto Wo = out_nkhw_desc.GetLength(Number<3>{});
auto add_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<N, K, Ho * 2, Wo * 2>{});
ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
ostream_tensor_descriptor(add_nkhw_desc, std::cout << "add_nkhw_desc: ");
print_array("LeftPads", to_multi_index(LeftPads{})); print_array("LeftPads", to_multi_index(LeftPads{}));
print_array("RightPads", to_multi_index(RightPads{})); print_array("RightPads", to_multi_index(RightPads{}));
print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvStrides", to_multi_index(ConvStrides{}));
...@@ -654,10 +661,10 @@ int main(int argc, char* argv[]) ...@@ -654,10 +661,10 @@ int main(int argc, char* argv[])
Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc)); Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc)); Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(out_nkhw_desc)); Tensor<out_data_t> add_nkhw(make_HostTensorDescriptor(add_nkhw_desc));
Tensor<out_data_t> add_nkhw(make_HostTensorDescriptor(out_nkhw_desc)); Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(add_nkhw_desc));
Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc)); Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(add_nkhw_desc));
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
...@@ -775,8 +782,9 @@ int main(int argc, char* argv[]) ...@@ -775,8 +782,9 @@ int main(int argc, char* argv[])
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
out_nkhw_desc, add_nkhw_desc,
add_nkhw, add_nkhw,
out_nkhw_desc,
out_nkhw_device, out_nkhw_device,
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment