Commit c1159e3c authored by Jing Zhang's avatar Jing Zhang
Browse files

use vec

parent 25b71afc
...@@ -39,15 +39,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -39,15 +39,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
typename InRightPads> typename InRightPads>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, __host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc, const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Add...>& add_n_k0_hox2_wox2_k1_global_desc, const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
const FloatC* __restrict__ p_d_global,
FloatC* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -58,18 +56,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -58,18 +56,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); const auto K0 = out_n_k0_ho_wo_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); const auto Ho = out_n_k0_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); const auto Wo = out_n_k0_ho_wo_global_desc.GetLength(I3);
const auto Hox2 = Ho * 2;
const auto Wox2 = Wo * 2;
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
const auto K = wei_k_c_y_x_global_desc.GetLength(I0); const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
...@@ -141,21 +134,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -141,21 +134,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
// output tensor // output tensor
const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor( const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)), make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo)),
make_tuple(make_merge_transform(make_tuple(K0, K1)), make_tuple(make_pass_through_transform(K0),
make_pass_through_transform(N), make_pass_through_transform(N),
make_pad_transform(Ho, 0, OutRightPadH), make_pad_transform(Ho, 0, OutRightPadH),
make_pad_transform(Wo, 0, OutRightPadW)), make_pad_transform(Wo, 0, OutRightPadW)),
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// add tensor
const auto add_k_n_hopx2_wopx2_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2)),
make_tuple(make_pass_through_transform(K0),
make_pass_through_transform(N),
make_pad_transform(Hox2, 0, AddRightPadH),
make_pad_transform(Wox2, 0, AddRightPadW)),
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
...@@ -192,17 +175,17 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -192,17 +175,17 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format // hack for NKHW format
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 2, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}));
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3< using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v2<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
...@@ -210,7 +193,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -210,7 +193,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
decltype(add_k_n_hopx2_wopx2_global_desc),
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
...@@ -270,8 +252,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -270,8 +252,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
...@@ -286,8 +266,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -286,8 +266,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
add_k_n_hopx2_wopx2_global_desc,
p_d_global,
out_k_n_hop_wop_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
...@@ -301,8 +279,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -301,8 +279,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
...@@ -317,8 +293,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -317,8 +293,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
add_k_n_hopx2_wopx2_global_desc,
p_d_global,
out_k_n_hop_wop_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
...@@ -332,8 +306,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -332,8 +306,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
...@@ -348,8 +320,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -348,8 +320,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
add_k_n_hopx2_wopx2_global_desc,
p_d_global,
out_k_n_hop_wop_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
...@@ -363,8 +333,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -363,8 +333,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
...@@ -379,8 +347,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -379,8 +347,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
add_k_n_hopx2_wopx2_global_desc,
p_d_global,
out_k_n_hop_wop_global_desc, out_k_n_hop_wop_global_desc,
p_out_global, p_out_global,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
...@@ -394,7 +360,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -394,7 +360,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc, wei_k_c_y_x_global_desc,
out_n_k0_ho_wo_k1_global_desc) / out_n_k0_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time; (std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
......
...@@ -18,7 +18,6 @@ template <index_t BlockSize, ...@@ -18,7 +18,6 @@ template <index_t BlockSize,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc, typename AGlobalDesc,
typename BGlobalDesc, typename BGlobalDesc,
typename DGlobalDesc,
typename CGlobalDesc, typename CGlobalDesc,
index_t KPerBlock, index_t KPerBlock,
index_t HoPerBlock, index_t HoPerBlock,
...@@ -48,7 +47,7 @@ template <index_t BlockSize, ...@@ -48,7 +47,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks, typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v3 struct GridwiseDynamicGemm_km_kn_mn_v2
{ {
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
...@@ -74,8 +73,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -74,8 +73,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
...@@ -353,183 +350,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -353,183 +350,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
#endif #endif
// output: register to global memory // output: register to global memory
#if 0
{
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
const index_t hox2_block_data_on_global = ho_block_work_id * HoPerBlock * 2;
const index_t wox2_block_data_on_global = wo_block_work_id * WoPerBlock * 2;
const index_t hox2_thread_data_on_global =
hox2_block_data_on_global + ho_thread_id * HoPerThreadx2;
const index_t wox2_thread_data_on_global =
wox2_block_data_on_global + wo_thread_id * WoPerThreadx2;
static_assert(KPerThread % CThreadTransferDstScalarPerVector == 0, "");
constexpr auto KPerThreadAdd = KPerThread / CThreadTransferDstScalarPerVector;
const index_t k_block_data_on_global_add =
k_block_work_id * KPerBlock / CThreadTransferDstScalarPerVector;
const index_t k_thread_data_on_global_add =
k_block_data_on_global_add + k_thread_id * KPerThreadAdd;
constexpr auto d_k_n_hox2_wox2_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<1>{}, Number<1>{}, Number<1>{}, Number<1>{}));
constexpr auto vector_len = CThreadTransferDstScalarPerVector;
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
vector_type<int8_t, vector_len> d_vec;
for(index_t k_i = 0; k_i < KPerThreadAdd; ++k_i)
{
for(index_t h_i = 0; h_i < HoPerThreadx2; ++h_i)
{
for(index_t w_i = 0; w_i < WoPerThreadx2; ++w_i)
{
ThreadwiseDynamicTensorSliceTransfer_v2<
FloatC,
decltype(d_vec),
decltype(d_k_n_hox2_wox2_global_desc),
decltype(d_k_n_hox2_wox2_thread_desc),
Sequence<1, 1, 1, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
// CThreadTransferDstScalarPerVector,
1,
AddressSpace::Global,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>(d_k_n_hox2_wox2_global_desc,
make_multi_index(k_thread_data_on_global_add + k_i,
0,
hox2_thread_data_on_global + h_i,
wox2_thread_data_on_global + w_i))
.Run2(d_k_n_hox2_wox2_global_desc,
p_d_global,
d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0),
d_vec,
c_k_n_ho_wo_global_tensor_iterator_hacks);
static_for<0, vector_len, 1>{}([&](auto i) {
d_vec.template AsType<int8_t>()(i) +=
p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
make_tuple(k_i * vector_len + i, 0, h_i / 2, w_i / 2))];
});
ThreadwiseDynamicTensorSliceTransfer_v1r3<
decltype(d_vec),
FloatC,
decltype(d_k_n_hox2_wox2_thread_desc),
decltype(d_k_n_hox2_wox2_global_desc),
Sequence<1, 1, 1, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
// CThreadTransferDstScalarPerVector,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
true>(d_k_n_hox2_wox2_global_desc,
make_multi_index(k_thread_data_on_global_add + k_i,
0,
hox2_thread_data_on_global + h_i,
wox2_thread_data_on_global + w_i))
.Run2(d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0),
d_vec,
d_k_n_hox2_wox2_global_desc,
p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks);
}
}
}
}
#else
{ {
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
const index_t hox2_block_data_on_global = ho_block_work_id * HoPerBlock * 2;
const index_t wox2_block_data_on_global = wo_block_work_id * WoPerBlock * 2;
const index_t hox2_thread_data_on_global =
hox2_block_data_on_global + ho_thread_id * HoPerThreadx2;
const index_t wox2_thread_data_on_global =
wox2_block_data_on_global + wo_thread_id * WoPerThreadx2;
static_assert(KPerThread % CThreadTransferDstScalarPerVector == 0, "");
static_assert(CThreadTransferDstScalarPerVector == 16, "");
constexpr auto KPerThreadAdd = KPerThread / CThreadTransferDstScalarPerVector;
const index_t k_block_data_on_global_add = static_assert(CThreadTransferDstScalarPerVector == 16 && KPerBlock == 16, "");
k_block_work_id * KPerBlock / CThreadTransferDstScalarPerVector; const index_t k_block_data_on_global_vec =
const index_t k_thread_data_on_global_add = k_block_work_id * (KPerBlock / CThreadTransferDstScalarPerVector);
k_block_data_on_global_add + k_thread_id * KPerThreadAdd; const index_t KPerThreadVec = KPerThread / CThreadTransferDstScalarPerVector;
const index_t k_thread_data_on_global_vec =
k_block_data_on_global_vec + k_thread_id * KPerThreadVec;
constexpr auto d_k_n_hox2_wox2_thread_desc = constexpr auto c_k_n_ho_wo_thread_desc_vec =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerThreadAdd>{}, make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerThreadVec>{},
Number<1>{}, Number<1>{},
Number<HoPerThreadx2>{}, Number<HoPerThread>{},
Number<WoPerThreadx2>{})); Number<WoPerThread>{}));
constexpr auto vec_len = d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize() * static_assert(c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize() == 4, "");
CThreadTransferDstScalarPerVector;
static_assert(vec_len == 256, ""); FloatC d_vec[c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize()];
// vector_type<int8_t, vec_len> d_vec;
FloatC d_vec[d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize()];
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v2< static_for<0, KPerThreadVec, 1>{}([&](auto k_i) {
FloatC, static_for<0, HoPerThread, 1>{}([&](auto h_i) {
// decltype(d_vec), static_for<0, WoPerThread, 1>{}([&](auto w_i) {
FloatC,
decltype(d_k_n_hox2_wox2_global_desc),
decltype(d_k_n_hox2_wox2_thread_desc),
Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
// CThreadTransferDstScalarPerVector,
1,
AddressSpace::Global,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>(d_k_n_hox2_wox2_global_desc,
make_multi_index(k_thread_data_on_global_add,
0,
hox2_thread_data_on_global,
wox2_thread_data_on_global))
.Run(d_k_n_hox2_wox2_global_desc,
p_d_global,
d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0),
d_vec,
c_k_n_ho_wo_global_tensor_iterator_hacks);
static_for<0, KPerThreadAdd, 1>{}([&](auto k_i) {
static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) {
static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) {
vector_type<int8_t, CThreadTransferDstScalarPerVector> t; vector_type<int8_t, CThreadTransferDstScalarPerVector> t;
// t.template AsType<FloatC>()(Number<0>{}) = d_vec.template AsType< // t.template AsType<FloatC>()(Number<0>{}) = d_vec.template AsType<
// FloatC>()[Number<d_k_n_hox2_wox2_thread_desc.CalculateOffset( // FloatC>()[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(
// make_tuple(k_i, 0, h_i, w_i))>{}]; // make_tuple(k_i, 0, h_i, w_i))>{}];
t.template AsType<FloatC>()(Number<0>{}) = t.template AsType<FloatC>()(Number<0>{}) =
d_vec[Number<d_k_n_hox2_wox2_thread_desc.CalculateOffset( d_vec[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(
make_tuple(k_i, 0, h_i, w_i))>{}]; make_tuple(k_i, 0, h_i, w_i))>{}];
static_for<0, CThreadTransferDstScalarPerVector, 1>{}([&](auto i) { static_for<0, CThreadTransferDstScalarPerVector, 1>{}([&](auto i) {
t.template AsType<int8_t>()(i) += t.template AsType<int8_t>()(i) =
p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset( p_c_thread[c_k_n_ho_wo_thread_desc_vec.CalculateOffset(
make_tuple(k_i * CThreadTransferDstScalarPerVector + i, make_tuple(k_i * CThreadTransferDstScalarPerVector + i,
0, 0,
h_i / 2, h_i / 2,
...@@ -537,22 +394,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -537,22 +394,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
}); });
// d_vec.template AsType<FloatC>()( // d_vec.template AsType<FloatC>()(
// Number<d_k_n_hox2_wox2_thread_desc.CalculateOffset(make_tuple( // Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(make_tuple(
// k_i, 0, h_i, w_i))>{}) = t.template AsType<FloatC>()[Number<0>{}]; // k_i, 0, h_i, w_i))>{}) = t.template AsType<FloatC>()[Number<0>{}];
d_vec[Number<d_k_n_hox2_wox2_thread_desc.CalculateOffset(make_tuple( d_vec[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(make_tuple(
k_i, 0, h_i, w_i))>{}] = t.template AsType<FloatC>()[Number<0>{}]; k_i, 0, h_i, w_i))>{}] = t.template AsType<FloatC>()[Number<0>{}];
}); });
}); });
}); });
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
// decltype(d_vec),
FloatC, FloatC,
FloatC, FloatC,
decltype(d_k_n_hox2_wox2_thread_desc), decltype(c_k_n_ho_wo_thread_desc_vec),
decltype(d_k_n_hox2_wox2_global_desc), decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>, Sequence<KPerThreadVec, 1, HoPerThread, WoPerThread>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
// CThreadTransferDstScalarPerVector, // CThreadTransferDstScalarPerVector,
...@@ -561,19 +417,18 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -561,19 +417,18 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
AddressSpace::Global, AddressSpace::Global,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>(d_k_n_hox2_wox2_global_desc, true>(c_k_n_ho_wo_global_desc,
make_multi_index(k_thread_data_on_global_add, make_multi_index(k_thread_data_on_global_vec,
0, 0,
hox2_thread_data_on_global, ho_thread_data_on_global,
wox2_thread_data_on_global)) wo_thread_data_on_global))
.Run(d_k_n_hox2_wox2_thread_desc, .Run(c_k_n_ho_wo_thread_desc_vec,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
d_vec, d_vec,
d_k_n_hox2_wox2_global_desc, c_k_n_ho_wo_global_desc,
p_c_global, p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
} }
#endif
} }
// pass tensor descriptor by reference // pass tensor descriptor by reference
...@@ -582,8 +437,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -582,8 +437,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
...@@ -597,8 +450,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -597,8 +450,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global, p_a_global,
b_e_n_ho_wo_global_desc, b_e_n_ho_wo_global_desc,
p_b_global, p_b_global,
d_k_n_hox2_wox2_global_desc,
p_d_global,
c_k_n_ho_wo_global_desc, c_k_n_ho_wo_global_desc,
p_c_global, p_c_global,
p_shared_block, p_shared_block,
...@@ -612,8 +463,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -612,8 +463,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc, const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc, const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
...@@ -627,8 +476,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -627,8 +476,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global, p_a_global,
b_e_n_ho_wo_global_desc, b_e_n_ho_wo_global_desc,
p_b_global, p_b_global,
d_k_n_hox2_wox2_global_desc,
p_d_global,
c_k_n_ho_wo_global_desc, c_k_n_ho_wo_global_desc,
p_c_global, p_c_global,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
...@@ -641,8 +488,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -641,8 +488,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc, const void* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global,
const void* p_c_k_n_ho_wo_global_desc, const void* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
...@@ -658,8 +503,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -658,8 +503,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global, p_a_global,
b_e_n_ho_wo_global_desc, b_e_n_ho_wo_global_desc,
p_b_global, p_b_global,
d_k_n_hox2_wox2_global_desc,
p_d_global,
c_k_n_ho_wo_global_desc, c_k_n_ho_wo_global_desc,
p_c_global, p_c_global,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
......
...@@ -10,7 +10,6 @@ template <class TInWei, ...@@ -10,7 +10,6 @@ template <class TInWei,
class TOut, class TOut,
class InDesc, class InDesc,
class WeiDesc, class WeiDesc,
class AddDesc,
class OutDesc, class OutDesc,
class ConvStrides, class ConvStrides,
class ConvDilations, class ConvDilations,
...@@ -21,10 +20,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -21,10 +20,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
const Tensor<TInWei>& in_n_c_hi_wi, const Tensor<TInWei>& in_n_c_hi_wi,
WeiDesc, WeiDesc,
const Tensor<TInWei>& wei_k_c_y_x, const Tensor<TInWei>& wei_k_c_y_x,
AddDesc,
const Tensor<TOut>& add_n_k_hox2_wox2,
OutDesc, OutDesc,
Tensor<TOut>& out_n_k_hox2_wox2, Tensor<TOut>& out_n_k_ho_wo,
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
InLeftPads, InLeftPads,
...@@ -38,10 +35,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -38,10 +35,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem add_n_k_hox2_wox2_device_buf(sizeof(TOut) * DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
add_n_k_hox2_wox2.mDesc.GetElementSpace());
DeviceMem out_n_k_hox2_wox2_device_buf(sizeof(TOut) *
add_n_k_hox2_wox2.mDesc.GetElementSpace());
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -58,9 +52,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -58,9 +52,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr auto Ho = OutDesc::GetLengths()[I2]; constexpr auto Ho = OutDesc::GetLengths()[I2];
constexpr auto Wo = OutDesc::GetLengths()[I3]; constexpr auto Wo = OutDesc::GetLengths()[I3];
constexpr auto Hox2 = Ho * 2;
constexpr auto Wox2 = Wo * 2;
constexpr auto Y = WeiDesc::GetLengths()[I2]; constexpr auto Y = WeiDesc::GetLengths()[I2];
constexpr auto X = WeiDesc::GetLengths()[I3]; constexpr auto X = WeiDesc::GetLengths()[I3];
...@@ -76,7 +67,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -76,7 +67,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
const auto wei_k_c_y_x_desc = const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
const auto out_n_k_hox2_wox2_desc = const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
const auto conv_strides = to_multi_index(ConvStrides{}); const auto conv_strides = to_multi_index(ConvStrides{});
...@@ -89,10 +80,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -89,10 +80,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
const auto wei_k_c0_y_x_desc = const auto wei_k_c0_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
const auto out_n_k0_ho_wo_k1_desc = const auto out_n_k0_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo));
const auto add_n_k0_hox2_wox2_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
...@@ -104,10 +93,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -104,10 +93,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{}))); make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{})));
Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor( Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{}))); make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{})));
Tensor<TOut> add_n_k0_hox2_wox2_k1(make_HostTensorDescriptor( Tensor<TOut> out_n_k0_ho_wo_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, Hox2, Wox2, K1>{}))); make_native_tensor_descriptor_packed(Sequence<N, K0, Ho, Wo, K1>{})));
Tensor<TOut> out_n_k0_hox2_wox2_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, Hox2, Wox2, K1>{})));
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) { auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) = in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
...@@ -119,18 +106,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -119,18 +106,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
wei_k_c_y_x(k, c, y, x); wei_k_c_y_x(k, c, y, x);
}; };
auto f_nkhw_to_nk0hwk1 = [&](auto n, auto k, auto ho, auto wo) {
add_n_k0_hox2_wox2_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize) =
add_n_k_hox2_wox2(n, k, ho, wo);
};
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)(); make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)(); make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
make_ParallelTensorFunctor(f_nkhw_to_nk0hwk1, N, K, Hox2, Wox2)();
in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
add_n_k_hox2_wox2_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());
#if 1 #if 1
// cdata = 64, BlockSize = 64, 16x8x32x4 // cdata = 64, BlockSize = 64, 16x8x32x4
...@@ -210,8 +190,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -210,8 +190,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
conv_driver.Run(wei_k_c0_y_x_desc, conv_driver.Run(wei_k_c0_y_x_desc,
in_n_c0_hi_wi_desc, in_n_c0_hi_wi_desc,
add_n_k0_hox2_wox2_desc, out_n_k0_ho_wo_desc,
out_n_k0_ho_wo_k1_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
...@@ -221,18 +200,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -221,18 +200,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()), in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TOut, InWeiVectorSize>::type*>( static_cast<typename vector_type<TOut, InWeiVectorSize>::type*>(
add_n_k_hox2_wox2_device_buf.GetDeviceBuffer()), out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
static_cast<typename vector_type<TOut, InWeiVectorSize>::type*>(
out_n_k_hox2_wox2_device_buf.GetDeviceBuffer()));
out_n_k_hox2_wox2_device_buf.FromDevice(out_n_k0_hox2_wox2_k1.mData.data()); out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
#if 1 #if 1
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) { auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_n_k_hox2_wox2(n, k, ho, wo) = out_n_k_ho_wo(n, k, ho, wo) =
out_n_k0_hox2_wox2_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize); out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
}; };
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Hox2, Wox2)(); make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
#endif #endif
} }
...@@ -10,7 +10,6 @@ template <class TIn, ...@@ -10,7 +10,6 @@ template <class TIn,
class UpperPads> class UpperPads>
void host_direct_convolution(const Tensor<TIn>& in_nchw, void host_direct_convolution(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx, const Tensor<TWei>& wei_kcyx,
const Tensor<TOut>& add_nkhw,
Tensor<TOut>& out_nkhw, Tensor<TOut>& out_nkhw,
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
...@@ -41,21 +40,14 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw, ...@@ -41,21 +40,14 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
} }
} }
} }
out_nkhw(n, k, ho, wo) = v;
index_t hox2 = ho * 2;
index_t wox2 = wo * 2;
out_nkhw(n, k, hox2, wox2) = v + add_nkhw(n, k, hox2, wox2);
out_nkhw(n, k, hox2, wox2 + 1) = v + add_nkhw(n, k, hox2, wox2 + 1);
out_nkhw(n, k, hox2 + 1, wox2) = v + add_nkhw(n, k, hox2 + 1, wox2);
out_nkhw(n, k, hox2 + 1, wox2 + 1) = v + add_nkhw(n, k, hox2 + 1, wox2 + 1);
}; };
auto f_par = make_ParallelTensorFunctor(f, auto f_par = make_ParallelTensorFunctor(f,
out_nkhw.mDesc.GetLengths()[0], out_nkhw.mDesc.GetLengths()[0],
out_nkhw.mDesc.GetLengths()[1], out_nkhw.mDesc.GetLengths()[1],
out_nkhw.mDesc.GetLengths()[2] / 2, out_nkhw.mDesc.GetLengths()[2],
out_nkhw.mDesc.GetLengths()[3] / 2); out_nkhw.mDesc.GetLengths()[3]);
f_par(std::thread::hardware_concurrency()); f_par(std::thread::hardware_concurrency());
} }
......
...@@ -625,12 +625,12 @@ int main(int argc, char* argv[]) ...@@ -625,12 +625,12 @@ int main(int argc, char* argv[])
constexpr auto Ho = out_nkhw_desc.GetLength(Number<2>{}); constexpr auto Ho = out_nkhw_desc.GetLength(Number<2>{});
constexpr auto Wo = out_nkhw_desc.GetLength(Number<3>{}); constexpr auto Wo = out_nkhw_desc.GetLength(Number<3>{});
auto add_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<N, K, Ho * 2, Wo * 2>{}); // auto add_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<N, K, Ho * 2, Wo * 2>{});
ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
ostream_tensor_descriptor(add_nkhw_desc, std::cout << "add_nkhw_desc: "); // ostream_tensor_descriptor(add_nkhw_desc, std::cout << "add_nkhw_desc: ");
print_array("LeftPads", to_multi_index(LeftPads{})); print_array("LeftPads", to_multi_index(LeftPads{}));
print_array("RightPads", to_multi_index(RightPads{})); print_array("RightPads", to_multi_index(RightPads{}));
...@@ -661,10 +661,9 @@ int main(int argc, char* argv[]) ...@@ -661,10 +661,9 @@ int main(int argc, char* argv[])
Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc)); Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc)); Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
Tensor<out_data_t> add_nkhw(make_HostTensorDescriptor(add_nkhw_desc));
Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(add_nkhw_desc)); Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(out_nkhw_desc));
Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(add_nkhw_desc)); Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc));
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
...@@ -700,8 +699,6 @@ int main(int argc, char* argv[]) ...@@ -700,8 +699,6 @@ int main(int argc, char* argv[])
}; };
wei_kcyx.GenerateTensorValue(gen_wei, num_thread); wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
#endif #endif
// add_nkhw.GenerateTensorValue(GeneratorTensor_2{-1, 1}, num_thread);
add_nkhw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
} }
#if 0 #if 0
...@@ -783,8 +780,6 @@ int main(int argc, char* argv[]) ...@@ -783,8 +780,6 @@ int main(int argc, char* argv[])
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
add_nkhw_desc,
add_nkhw,
out_nkhw_desc, out_nkhw_desc,
out_nkhw_device, out_nkhw_device,
ConvStrides{}, ConvStrides{},
...@@ -798,7 +793,6 @@ int main(int argc, char* argv[]) ...@@ -798,7 +793,6 @@ int main(int argc, char* argv[])
{ {
host_direct_convolution(in_nchw, host_direct_convolution(in_nchw,
wei_kcyx, wei_kcyx,
add_nkhw,
out_nkhw_host, out_nkhw_host,
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment