Commit 03aa52bc authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed

parent db4afa69
......@@ -39,7 +39,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
typename InRightPads>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Add...>& add_n_k0_2ho_2wo_k1_global_desc,
const DynamicTensorDescriptor<Add...>& add_n_k0_hox2_wox2_k1_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
......@@ -66,6 +66,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
const auto Hox2 = Ho * 2;
const auto Wox2 = Wo * 2;
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
......@@ -146,18 +149,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// add tensor
const auto add_k_n_2hop_2wop_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, 2 * Ho, 2 * Wo, K1)),
const auto add_k_n_hopx2_wopx2_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_pad_transform(2 * Ho, 0, AddRightPadH),
make_pad_transform(2 * Wo, 0, AddRightPadW)),
make_pad_transform(Hox2, 0, AddRightPadH),
make_pad_transform(Wox2, 0, AddRightPadW)),
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto E = C * Y * X;
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
......@@ -209,7 +210,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
InMemoryDataOperation::Set,
decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc),
decltype(add_k_n_2hop_2wop_global_desc),
decltype(add_k_n_hopx2_wopx2_global_desc),
decltype(out_k_n_hop_wop_global_desc),
KPerBlock,
HoPerBlock,
......@@ -269,7 +270,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(add_k_n_2hop_2wop_global_desc),
decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
......@@ -285,7 +286,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
add_k_n_2hop_2wop_global_desc,
add_k_n_hopx2_wopx2_global_desc,
p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global,
......@@ -300,7 +301,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(add_k_n_2hop_2wop_global_desc),
decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
......@@ -316,7 +317,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
add_k_n_2hop_2wop_global_desc,
add_k_n_hopx2_wopx2_global_desc,
p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global,
......@@ -331,7 +332,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(add_k_n_2hop_2wop_global_desc),
decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
......@@ -347,7 +348,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
add_k_n_2hop_2wop_global_desc,
add_k_n_hopx2_wopx2_global_desc,
p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global,
......@@ -362,7 +363,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(add_k_n_2hop_2wop_global_desc),
decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
......@@ -378,7 +379,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
add_k_n_2hop_2wop_global_desc,
add_k_n_hopx2_wopx2_global_desc,
p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global,
......
......@@ -74,7 +74,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_2ho_2wo_global_desc,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
......@@ -89,7 +89,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr auto E = EPerBlock * 3 * 3;
// const auto E = a_e_k_global_desc.GetLength(I0);
const auto K = a_e_k_global_desc.GetLength(I1);
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
......@@ -148,10 +147,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
constexpr auto d_k_n_2ho_2wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<2 * HoPerThread>{}, Number<2 * WoPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
decltype(a_e_k_block_desc),
......@@ -358,29 +353,38 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_c_thread);
}
#endif
// output: register to global memory
{
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
#if 1
FloatC p_d_thread[d_k_n_2ho_2wo_thread_desc.GetElementSpaceSize()];
threadwise_matrix_set_zero_v3(d_k_n_2ho_2wo_thread_desc, p_d_thread);
const index_t hox2_block_data_on_global = ho_block_work_id * HoPerBlock * 2;
const index_t wox2_block_data_on_global = wo_block_work_id * WoPerBlock * 2;
const index_t ho2_thread_data_on_global =
ho_block_data_on_global + ho_thread_id * HoPerThread * 2;
const index_t wo2_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread * 2;
const index_t hox2_thread_data_on_global =
hox2_block_data_on_global + ho_thread_id * HoPerThreadx2;
const index_t wox2_thread_data_on_global =
wox2_block_data_on_global + wo_thread_id * WoPerThreadx2;
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
{
constexpr auto d_k_n_hox2_wox2_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerThread>{},
Number<1>{},
Number<HoPerThreadx2>{},
Number<WoPerThreadx2>{}));
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
FloatC p_d_thread[d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize()];
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
#if 1
ThreadwiseDynamicTensorSliceTransfer_v2<
FloatC,
FloatC,
decltype(d_k_n_2ho_2wo_global_desc),
decltype(d_k_n_2ho_2wo_thread_desc),
Sequence<KPerThread, 1, 2 * HoPerThread, 2 * WoPerThread>,
decltype(d_k_n_hox2_wox2_global_desc),
decltype(d_k_n_hox2_wox2_thread_desc),
Sequence<KPerThread, 1, HoPerThreadx2, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
......@@ -388,36 +392,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>(
d_k_n_2ho_2wo_global_desc,
make_multi_index(
k_thread_data_on_global, 0, ho2_thread_data_on_global, wo2_thread_data_on_global))
.Run(d_k_n_2ho_2wo_global_desc,
true>(d_k_n_hox2_wox2_global_desc,
make_multi_index(k_thread_data_on_global,
0,
hox2_thread_data_on_global,
wox2_thread_data_on_global))
.Run(d_k_n_hox2_wox2_global_desc,
p_d_global,
d_k_n_2ho_2wo_thread_desc,
d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0),
p_d_thread,
c_k_n_ho_wo_global_tensor_iterator_hacks);
}
for(index_t i = 0; i < d_k_n_2ho_2wo_thread_desc.GetElementSpaceSize(); i++)
{
p_d_thread[i] += p_c_thread[i / 2];
}
#endif
#if 1
// output: register to global memory
for(index_t k_i = 0; k_i < KPerThread; ++k_i)
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
for(index_t h_i = 0; h_i < HoPerThreadx2; ++h_i)
{
for(index_t w_i = 0; w_i < WoPerThreadx2; ++w_i)
{
p_d_thread[d_k_n_hox2_wox2_thread_desc.CalculateOffset(
make_tuple(k_i, 0, h_i, w_i))] +=
p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
make_tuple(k_i, 0, h_i / 2, w_i / 2))];
}
}
}
#endif
#if 1
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(d_k_n_2ho_2wo_thread_desc),
decltype(d_k_n_2ho_2wo_global_desc),
Sequence<KPerThread, 1, 2 * HoPerThread, 2 * WoPerThread>,
decltype(d_k_n_hox2_wox2_thread_desc),
decltype(d_k_n_hox2_wox2_global_desc),
Sequence<KPerThread, 1, HoPerThreadx2, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
......@@ -425,19 +436,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
true>(
d_k_n_2ho_2wo_global_desc,
make_multi_index(
k_thread_data_on_global, 0, ho2_thread_data_on_global, wo2_thread_data_on_global))
.Run(d_k_n_2ho_2wo_thread_desc,
true>(d_k_n_hox2_wox2_global_desc,
make_multi_index(k_thread_data_on_global,
0,
hox2_thread_data_on_global,
wox2_thread_data_on_global))
.Run(d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0),
p_d_thread,
d_k_n_2ho_2wo_global_desc,
d_k_n_hox2_wox2_global_desc,
p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks);
}
#endif
}
}
// pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
......@@ -445,7 +457,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_2ho_2wo_global_desc,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
......@@ -460,7 +472,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
d_k_n_2ho_2wo_global_desc,
d_k_n_hox2_wox2_global_desc,
p_d_global,
c_k_n_ho_wo_global_desc,
p_c_global,
......@@ -475,7 +487,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_2ho_2wo_global_desc,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
......@@ -490,7 +502,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
d_k_n_2ho_2wo_global_desc,
d_k_n_hox2_wox2_global_desc,
p_d_global,
c_k_n_ho_wo_global_desc,
p_c_global,
......@@ -504,7 +516,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_2ho_2wo_global_desc,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global,
const void* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
......@@ -521,14 +533,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
d_k_n_2ho_2wo_global_desc,
d_k_n_hox2_wox2_global_desc,
p_d_global,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
}; // namespace ck
} // namespace ck
#endif
......@@ -22,9 +22,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
WeiDesc,
const Tensor<TInWei>& wei_k_c_y_x,
AddDesc,
const Tensor<TOut>& add_n_k_2ho_2wo,
const Tensor<TOut>& add_n_k_hox2_wox2,
OutDesc,
Tensor<TOut>& out_n_k_ho_wo,
Tensor<TOut>& out_n_k_hox2_wox2,
ConvStrides,
ConvDilations,
InLeftPads,
......@@ -38,8 +38,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem add_n_k_2ho_2wo_device_buf(sizeof(TOut) * add_n_k_2ho_2wo.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * add_n_k_2ho_2wo.mDesc.GetElementSpace());
DeviceMem add_n_k_hox2_wox2_device_buf(sizeof(TOut) *
add_n_k_hox2_wox2.mDesc.GetElementSpace());
DeviceMem out_n_k_hox2_wox2_device_buf(sizeof(TOut) *
add_n_k_hox2_wox2.mDesc.GetElementSpace());
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -56,6 +58,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr auto Ho = OutDesc::GetLengths()[I2];
constexpr auto Wo = OutDesc::GetLengths()[I3];
constexpr auto Hox2 = Ho * 2;
constexpr auto Wox2 = Wo * 2;
constexpr auto Y = WeiDesc::GetLengths()[I2];
constexpr auto X = WeiDesc::GetLengths()[I3];
......@@ -71,7 +76,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
const auto out_n_k_ho_wo_desc =
const auto out_n_k_hox2_wox2_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
const auto conv_strides = to_multi_index(ConvStrides{});
......@@ -86,8 +91,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
const auto add_n_k0_2ho_2wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, 2 * Ho, 2 * Wo, K1));
const auto add_n_k0_hox2_wox2_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2, K1));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
......@@ -99,10 +104,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{})));
Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{})));
Tensor<TOut> add_n_k0_2ho_2wo_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, 2 * Ho, 2 * Wo, K1>{})));
Tensor<TOut> out_n_k0_ho_wo_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, 2 * Ho, 2 * Wo, K1>{})));
Tensor<TOut> add_n_k0_hox2_wox2_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, Hox2, Wox2, K1>{})));
Tensor<TOut> out_n_k0_hox2_wox2_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, Hox2, Wox2, K1>{})));
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
......@@ -115,17 +120,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
};
auto f_nkhw_to_nk0hwk1 = [&](auto n, auto k, auto ho, auto wo) {
add_n_k0_2ho_2wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize) =
add_n_k_2ho_2wo(n, k, ho, wo);
add_n_k0_hox2_wox2_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize) =
add_n_k_hox2_wox2(n, k, ho, wo);
};
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
make_ParallelTensorFunctor(f_nkhw_to_nk0hwk1, N, K, Ho, Wo)();
make_ParallelTensorFunctor(f_nkhw_to_nk0hwk1, N, K, Hox2, Wox2)();
in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
add_n_k_2ho_2wo_device_buf.ToDevice(add_n_k0_2ho_2wo_k1.mData.data());
add_n_k_hox2_wox2_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());
#if 1
// cdata = 64, BlockSize = 64, 16x8x32x4
......@@ -141,8 +146,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, KPerBlock>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
......@@ -205,7 +210,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
conv_driver.Run(wei_k_c0_y_x_desc,
in_n_c0_hi_wi_desc,
add_n_k0_2ho_2wo_k1_desc,
add_n_k0_hox2_wox2_k1_desc,
out_n_k0_ho_wo_k1_desc,
conv_strides,
conv_dilations,
......@@ -215,18 +220,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(add_n_k_2ho_2wo_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
static_cast<TOut*>(add_n_k_hox2_wox2_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_hox2_wox2_device_buf.GetDeviceBuffer()));
out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
out_n_k_hox2_wox2_device_buf.FromDevice(out_n_k0_hox2_wox2_k1.mData.data());
#if 0
#if 1
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_n_k_ho_wo(n, k, ho, wo) =
out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
out_n_k_hox2_wox2(n, k, ho, wo) =
out_n_k0_hox2_wox2_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
};
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Hox2, Wox2)();
#endif
}
......@@ -41,14 +41,21 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
}
}
}
out_nkhw(n, k, ho, wo) = v + add_nkhw(n, k, ho, wo);
index_t hox2 = ho * 2;
index_t wox2 = wo * 2;
out_nkhw(n, k, hox2, wox2) = v + add_nkhw(n, k, hox2, wox2);
out_nkhw(n, k, hox2, wox2 + 1) = v + add_nkhw(n, k, hox2, wox2 + 1);
out_nkhw(n, k, hox2 + 1, wox2) = v + add_nkhw(n, k, hox2 + 1, wox2);
out_nkhw(n, k, hox2 + 1, wox2 + 1) = v + add_nkhw(n, k, hox2 + 1, wox2 + 1);
};
auto f_par = make_ParallelTensorFunctor(f,
out_nkhw.mDesc.GetLengths()[0],
out_nkhw.mDesc.GetLengths()[1],
out_nkhw.mDesc.GetLengths()[2],
out_nkhw.mDesc.GetLengths()[3]);
out_nkhw.mDesc.GetLengths()[2] / 2,
out_nkhw.mDesc.GetLengths()[3] / 2);
f_par(std::thread::hardware_concurrency());
}
......
......@@ -88,7 +88,7 @@ int main(int argc, char* argv[])
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
......@@ -700,7 +700,8 @@ int main(int argc, char* argv[])
};
wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
#endif
add_nkhw.GenerateTensorValue(GeneratorTensor_2{-1, 1}, num_thread);
// add_nkhw.GenerateTensorValue(GeneratorTensor_2{-1, 1}, num_thread);
add_nkhw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
}
#if 0
......@@ -806,7 +807,7 @@ int main(int argc, char* argv[])
check_error(out_nkhw_host, out_nkhw_device);
#if 0
#if 1
if(do_log)
{
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment