"src/include/ConstantMergedTensorDescriptor.hpp" did not exist on "acd7082fe109aa4228dfca652e87cab96bc6837f"
Commit 03aa52bc authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed

parent db4afa69
...@@ -39,7 +39,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -39,7 +39,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
typename InRightPads> typename InRightPads>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, __host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc, const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Add...>& add_n_k0_2ho_2wo_k1_global_desc, const DynamicTensorDescriptor<Add...>& add_n_k0_hox2_wox2_k1_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc, const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
...@@ -66,6 +66,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -66,6 +66,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
const auto Hox2 = Ho * 2;
const auto Wox2 = Wo * 2;
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
const auto K = wei_k_c_y_x_global_desc.GetLength(I0); const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
...@@ -146,18 +149,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -146,18 +149,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// add tensor // add tensor
const auto add_k_n_2hop_2wop_global_desc = transform_dynamic_tensor_descriptor( const auto add_k_n_hopx2_wopx2_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, 2 * Ho, 2 * Wo, K1)), make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)), make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N), make_pass_through_transform(N),
make_pad_transform(2 * Ho, 0, AddRightPadH), make_pad_transform(Hox2, 0, AddRightPadH),
make_pad_transform(2 * Wo, 0, AddRightPadW)), make_pad_transform(Wox2, 0, AddRightPadW)),
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto E = C * Y * X; const auto E = C * Y * X;
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
...@@ -209,7 +210,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -209,7 +210,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
decltype(add_k_n_2hop_2wop_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
...@@ -269,7 +270,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -269,7 +270,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_2hop_2wop_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
...@@ -285,7 +286,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -285,7 +286,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
add_k_n_2hop_2wop_global_desc, add_k_n_hopx2_wopx2_global_desc,
p_d_global, p_d_global,
out_k_n_hop_wop_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
...@@ -300,7 +301,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -300,7 +301,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_2hop_2wop_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
...@@ -316,7 +317,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -316,7 +317,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
add_k_n_2hop_2wop_global_desc, add_k_n_hopx2_wopx2_global_desc,
p_d_global, p_d_global,
out_k_n_hop_wop_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
...@@ -331,7 +332,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -331,7 +332,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_2hop_2wop_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
...@@ -347,7 +348,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -347,7 +348,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
add_k_n_2hop_2wop_global_desc, add_k_n_hopx2_wopx2_global_desc,
p_d_global, p_d_global,
out_k_n_hop_wop_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
...@@ -362,7 +363,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -362,7 +363,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_2hop_2wop_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
...@@ -378,7 +379,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -378,7 +379,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
add_k_n_2hop_2wop_global_desc, add_k_n_hopx2_wopx2_global_desc,
p_d_global, p_d_global,
out_k_n_hop_wop_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
......
...@@ -74,7 +74,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -74,7 +74,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_2ho_2wo_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
...@@ -89,7 +89,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -89,7 +89,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr auto E = EPerBlock * 3 * 3; constexpr auto E = EPerBlock * 3 * 3;
// const auto E = a_e_k_global_desc.GetLength(I0);
const auto K = a_e_k_global_desc.GetLength(I1); const auto K = a_e_k_global_desc.GetLength(I1);
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1); const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
...@@ -148,10 +147,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -148,10 +147,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
constexpr auto d_k_n_2ho_2wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<2 * HoPerThread>{}, Number<2 * WoPerThread>{}));
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize, BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
decltype(a_e_k_block_desc), decltype(a_e_k_block_desc),
...@@ -358,29 +353,38 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -358,29 +353,38 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_c_thread); p_c_thread);
} }
#endif #endif
// output: register to global memory
{
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
#if 1 const index_t hox2_block_data_on_global = ho_block_work_id * HoPerBlock * 2;
FloatC p_d_thread[d_k_n_2ho_2wo_thread_desc.GetElementSpaceSize()]; const index_t wox2_block_data_on_global = wo_block_work_id * WoPerBlock * 2;
threadwise_matrix_set_zero_v3(d_k_n_2ho_2wo_thread_desc, p_d_thread);
const index_t ho2_thread_data_on_global = const index_t hox2_thread_data_on_global =
ho_block_data_on_global + ho_thread_id * HoPerThread * 2; hox2_block_data_on_global + ho_thread_id * HoPerThreadx2;
const index_t wo2_thread_data_on_global = const index_t wox2_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread * 2; wox2_block_data_on_global + wo_thread_id * WoPerThreadx2;
const index_t k_thread_data_on_global = const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread; k_block_data_on_global + k_thread_id * KPerThread;
{ constexpr auto d_k_n_hox2_wox2_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerThread>{},
Number<1>{},
Number<HoPerThreadx2>{},
Number<WoPerThreadx2>{}));
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; FloatC p_d_thread[d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize()];
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
#if 1
ThreadwiseDynamicTensorSliceTransfer_v2< ThreadwiseDynamicTensorSliceTransfer_v2<
FloatC, FloatC,
FloatC, FloatC,
decltype(d_k_n_2ho_2wo_global_desc), decltype(d_k_n_hox2_wox2_global_desc),
decltype(d_k_n_2ho_2wo_thread_desc), decltype(d_k_n_hox2_wox2_thread_desc),
Sequence<KPerThread, 1, 2 * HoPerThread, 2 * WoPerThread>, Sequence<KPerThread, 1, HoPerThreadx2, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
...@@ -388,36 +392,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -388,36 +392,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
AddressSpace::Vgpr, AddressSpace::Vgpr,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
1, 1,
true>( true>(d_k_n_hox2_wox2_global_desc,
d_k_n_2ho_2wo_global_desc, make_multi_index(k_thread_data_on_global,
make_multi_index( 0,
k_thread_data_on_global, 0, ho2_thread_data_on_global, wo2_thread_data_on_global)) hox2_thread_data_on_global,
.Run(d_k_n_2ho_2wo_global_desc, wox2_thread_data_on_global))
.Run(d_k_n_hox2_wox2_global_desc,
p_d_global, p_d_global,
d_k_n_2ho_2wo_thread_desc, d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_d_thread, p_d_thread,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
}
for(index_t i = 0; i < d_k_n_2ho_2wo_thread_desc.GetElementSpaceSize(); i++)
{
p_d_thread[i] += p_c_thread[i / 2];
}
#endif #endif
#if 1 #if 1
// output: register to global memory for(index_t k_i = 0; k_i < KPerThread; ++k_i)
{ {
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor for(index_t h_i = 0; h_i < HoPerThreadx2; ++h_i)
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; {
for(index_t w_i = 0; w_i < WoPerThreadx2; ++w_i)
{
p_d_thread[d_k_n_hox2_wox2_thread_desc.CalculateOffset(
make_tuple(k_i, 0, h_i, w_i))] +=
p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
make_tuple(k_i, 0, h_i / 2, w_i / 2))];
}
}
}
#endif
#if 1
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC, FloatC,
FloatC, FloatC,
decltype(d_k_n_2ho_2wo_thread_desc), decltype(d_k_n_hox2_wox2_thread_desc),
decltype(d_k_n_2ho_2wo_global_desc), decltype(d_k_n_hox2_wox2_global_desc),
Sequence<KPerThread, 1, 2 * HoPerThread, 2 * WoPerThread>, Sequence<KPerThread, 1, HoPerThreadx2, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
...@@ -425,18 +436,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -425,18 +436,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
AddressSpace::Global, AddressSpace::Global,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>( true>(d_k_n_hox2_wox2_global_desc,
d_k_n_2ho_2wo_global_desc, make_multi_index(k_thread_data_on_global,
make_multi_index( 0,
k_thread_data_on_global, 0, ho2_thread_data_on_global, wo2_thread_data_on_global)) hox2_thread_data_on_global,
.Run(d_k_n_2ho_2wo_thread_desc, wox2_thread_data_on_global))
.Run(d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_d_thread, p_d_thread,
d_k_n_2ho_2wo_global_desc, d_k_n_hox2_wox2_global_desc,
p_c_global, p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
}
#endif #endif
}
} }
// pass tensor descriptor by reference // pass tensor descriptor by reference
...@@ -445,7 +457,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -445,7 +457,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_2ho_2wo_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
...@@ -460,7 +472,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -460,7 +472,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global, p_a_global,
b_e_n_ho_wo_global_desc, b_e_n_ho_wo_global_desc,
p_b_global, p_b_global,
d_k_n_2ho_2wo_global_desc, d_k_n_hox2_wox2_global_desc,
p_d_global, p_d_global,
c_k_n_ho_wo_global_desc, c_k_n_ho_wo_global_desc,
p_c_global, p_c_global,
...@@ -475,7 +487,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -475,7 +487,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc, const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_2ho_2wo_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc, const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
...@@ -490,7 +502,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -490,7 +502,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global, p_a_global,
b_e_n_ho_wo_global_desc, b_e_n_ho_wo_global_desc,
p_b_global, p_b_global,
d_k_n_2ho_2wo_global_desc, d_k_n_hox2_wox2_global_desc,
p_d_global, p_d_global,
c_k_n_ho_wo_global_desc, c_k_n_ho_wo_global_desc,
p_c_global, p_c_global,
...@@ -504,7 +516,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -504,7 +516,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc, const void* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_2ho_2wo_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const void* p_c_k_n_ho_wo_global_desc, const void* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
...@@ -521,14 +533,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -521,14 +533,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global, p_a_global,
b_e_n_ho_wo_global_desc, b_e_n_ho_wo_global_desc,
p_b_global, p_b_global,
d_k_n_2ho_2wo_global_desc, d_k_n_hox2_wox2_global_desc,
p_d_global, p_d_global,
c_k_n_ho_wo_global_desc, c_k_n_ho_wo_global_desc,
p_c_global, p_c_global,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
}; }; // namespace ck
} // namespace ck } // namespace ck
#endif #endif
...@@ -22,9 +22,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -22,9 +22,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
WeiDesc, WeiDesc,
const Tensor<TInWei>& wei_k_c_y_x, const Tensor<TInWei>& wei_k_c_y_x,
AddDesc, AddDesc,
const Tensor<TOut>& add_n_k_2ho_2wo, const Tensor<TOut>& add_n_k_hox2_wox2,
OutDesc, OutDesc,
Tensor<TOut>& out_n_k_ho_wo, Tensor<TOut>& out_n_k_hox2_wox2,
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
InLeftPads, InLeftPads,
...@@ -38,8 +38,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -38,8 +38,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem add_n_k_2ho_2wo_device_buf(sizeof(TOut) * add_n_k_2ho_2wo.mDesc.GetElementSpace()); DeviceMem add_n_k_hox2_wox2_device_buf(sizeof(TOut) *
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * add_n_k_2ho_2wo.mDesc.GetElementSpace()); add_n_k_hox2_wox2.mDesc.GetElementSpace());
DeviceMem out_n_k_hox2_wox2_device_buf(sizeof(TOut) *
add_n_k_hox2_wox2.mDesc.GetElementSpace());
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -56,6 +58,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -56,6 +58,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr auto Ho = OutDesc::GetLengths()[I2]; constexpr auto Ho = OutDesc::GetLengths()[I2];
constexpr auto Wo = OutDesc::GetLengths()[I3]; constexpr auto Wo = OutDesc::GetLengths()[I3];
constexpr auto Hox2 = Ho * 2;
constexpr auto Wox2 = Wo * 2;
constexpr auto Y = WeiDesc::GetLengths()[I2]; constexpr auto Y = WeiDesc::GetLengths()[I2];
constexpr auto X = WeiDesc::GetLengths()[I3]; constexpr auto X = WeiDesc::GetLengths()[I3];
...@@ -71,7 +76,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -71,7 +76,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
const auto wei_k_c_y_x_desc = const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
const auto out_n_k_ho_wo_desc = const auto out_n_k_hox2_wox2_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
const auto conv_strides = to_multi_index(ConvStrides{}); const auto conv_strides = to_multi_index(ConvStrides{});
...@@ -86,8 +91,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -86,8 +91,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
const auto out_n_k0_ho_wo_k1_desc = const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
const auto add_n_k0_2ho_2wo_k1_desc = const auto add_n_k0_hox2_wox2_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, 2 * Ho, 2 * Wo, K1)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2, K1));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
...@@ -99,10 +104,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -99,10 +104,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{}))); make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{})));
Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor( Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{}))); make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{})));
Tensor<TOut> add_n_k0_2ho_2wo_k1(make_HostTensorDescriptor( Tensor<TOut> add_n_k0_hox2_wox2_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, 2 * Ho, 2 * Wo, K1>{}))); make_native_tensor_descriptor_packed(Sequence<N, K0, Hox2, Wox2, K1>{})));
Tensor<TOut> out_n_k0_ho_wo_k1(make_HostTensorDescriptor( Tensor<TOut> out_n_k0_hox2_wox2_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, 2 * Ho, 2 * Wo, K1>{}))); make_native_tensor_descriptor_packed(Sequence<N, K0, Hox2, Wox2, K1>{})));
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) { auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) = in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
...@@ -115,17 +120,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -115,17 +120,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
}; };
auto f_nkhw_to_nk0hwk1 = [&](auto n, auto k, auto ho, auto wo) { auto f_nkhw_to_nk0hwk1 = [&](auto n, auto k, auto ho, auto wo) {
add_n_k0_2ho_2wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize) = add_n_k0_hox2_wox2_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize) =
add_n_k_2ho_2wo(n, k, ho, wo); add_n_k_hox2_wox2(n, k, ho, wo);
}; };
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)(); make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)(); make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
make_ParallelTensorFunctor(f_nkhw_to_nk0hwk1, N, K, Ho, Wo)(); make_ParallelTensorFunctor(f_nkhw_to_nk0hwk1, N, K, Hox2, Wox2)();
in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
add_n_k_2ho_2wo_device_buf.ToDevice(add_n_k0_2ho_2wo_k1.mData.data()); add_n_k_hox2_wox2_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());
#if 1 #if 1
// cdata = 64, BlockSize = 64, 16x8x32x4 // cdata = 64, BlockSize = 64, 16x8x32x4
...@@ -141,8 +146,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -141,8 +146,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock; constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>; using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>; using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, KPerBlock>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1; constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
...@@ -205,7 +210,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -205,7 +210,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
conv_driver.Run(wei_k_c0_y_x_desc, conv_driver.Run(wei_k_c0_y_x_desc,
in_n_c0_hi_wi_desc, in_n_c0_hi_wi_desc,
add_n_k0_2ho_2wo_k1_desc, add_n_k0_hox2_wox2_k1_desc,
out_n_k0_ho_wo_k1_desc, out_n_k0_ho_wo_k1_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
...@@ -215,18 +220,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -215,18 +220,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf.GetDeviceBuffer()), wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()), in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(add_n_k_2ho_2wo_device_buf.GetDeviceBuffer()), static_cast<TOut*>(add_n_k_hox2_wox2_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer())); static_cast<TOut*>(out_n_k_hox2_wox2_device_buf.GetDeviceBuffer()));
out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); out_n_k_hox2_wox2_device_buf.FromDevice(out_n_k0_hox2_wox2_k1.mData.data());
#if 0 #if 1
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) { auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_n_k_ho_wo(n, k, ho, wo) = out_n_k_hox2_wox2(n, k, ho, wo) =
out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize); out_n_k0_hox2_wox2_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
}; };
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)(); make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Hox2, Wox2)();
#endif #endif
} }
...@@ -41,14 +41,21 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw, ...@@ -41,14 +41,21 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
} }
} }
} }
out_nkhw(n, k, ho, wo) = v + add_nkhw(n, k, ho, wo);
index_t hox2 = ho * 2;
index_t wox2 = wo * 2;
out_nkhw(n, k, hox2, wox2) = v + add_nkhw(n, k, hox2, wox2);
out_nkhw(n, k, hox2, wox2 + 1) = v + add_nkhw(n, k, hox2, wox2 + 1);
out_nkhw(n, k, hox2 + 1, wox2) = v + add_nkhw(n, k, hox2 + 1, wox2);
out_nkhw(n, k, hox2 + 1, wox2 + 1) = v + add_nkhw(n, k, hox2 + 1, wox2 + 1);
}; };
auto f_par = make_ParallelTensorFunctor(f, auto f_par = make_ParallelTensorFunctor(f,
out_nkhw.mDesc.GetLengths()[0], out_nkhw.mDesc.GetLengths()[0],
out_nkhw.mDesc.GetLengths()[1], out_nkhw.mDesc.GetLengths()[1],
out_nkhw.mDesc.GetLengths()[2], out_nkhw.mDesc.GetLengths()[2] / 2,
out_nkhw.mDesc.GetLengths()[3]); out_nkhw.mDesc.GetLengths()[3] / 2);
f_par(std::thread::hardware_concurrency()); f_par(std::thread::hardware_concurrency());
} }
......
...@@ -88,7 +88,7 @@ int main(int argc, char* argv[]) ...@@ -88,7 +88,7 @@ int main(int argc, char* argv[])
constexpr index_t X = 3; constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
...@@ -700,7 +700,8 @@ int main(int argc, char* argv[]) ...@@ -700,7 +700,8 @@ int main(int argc, char* argv[])
}; };
wei_kcyx.GenerateTensorValue(gen_wei, num_thread); wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
#endif #endif
add_nkhw.GenerateTensorValue(GeneratorTensor_2{-1, 1}, num_thread); // add_nkhw.GenerateTensorValue(GeneratorTensor_2{-1, 1}, num_thread);
add_nkhw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
} }
#if 0 #if 0
...@@ -806,7 +807,7 @@ int main(int argc, char* argv[]) ...@@ -806,7 +807,7 @@ int main(int argc, char* argv[])
check_error(out_nkhw_host, out_nkhw_device); check_error(out_nkhw_host, out_nkhw_device);
#if 0 #if 1
if(do_log) if(do_log)
{ {
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment