Commit c1159e3c authored by Jing Zhang

use vec

parent 25b71afc
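This change removes the fused elementwise-add path (the 2x-resolution add/D tensor) from the v5r1 forward convolution: the driver, gridwise GEMM, device wrapper, host reference, and test harness all drop their add_*/p_d_global arguments, and the GEMM's register-to-global write-back is reworked to pack int8 results into CThreadTransferDstScalarPerVector-wide vectors before storing, hence "use vec".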
@@ -39,15 +39,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
typename InRightPads>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
- const DynamicTensorDescriptor<Add...>& add_n_k0_hox2_wox2_k1_global_desc,
- const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
+ const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global,
- const FloatC* __restrict__ p_d_global,
FloatC* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
@@ -58,18 +56,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
- const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
+ const auto K0 = out_n_k0_ho_wo_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
- const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
- const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
- const auto Hox2 = Ho * 2;
- const auto Wox2 = Wo * 2;
- const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
+ const auto Ho = out_n_k0_ho_wo_global_desc.GetLength(I2);
+ const auto Wo = out_n_k0_ho_wo_global_desc.GetLength(I3);
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
@@ -141,21 +134,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
// output tensor
const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor(
- make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
- make_tuple(make_merge_transform(make_tuple(K0, K1)),
+ make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo)),
+ make_tuple(make_pass_through_transform(K0),
make_pass_through_transform(N),
make_pad_transform(Ho, 0, OutRightPadH),
make_pad_transform(Wo, 0, OutRightPadW)),
- make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
+ make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
- // add tensor
- const auto add_k_n_hopx2_wopx2_global_desc = transform_dynamic_tensor_descriptor(
- make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2)),
- make_tuple(make_pass_through_transform(K0),
- make_pass_through_transform(N),
- make_pad_transform(Hox2, 0, AddRightPadH),
- make_pad_transform(Wox2, 0, AddRightPadW)),
- make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
- make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
@@ -192,17 +175,17 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
- make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
+ make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
- make_tuple(Sequence<0, 2, 0, 0, 0>{},
+ make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
// GEMM
- using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
+ using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v2<
BlockSize,
FloatAB,
FloatAcc,
@@ -210,7 +193,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
InMemoryDataOperation::Set,
decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc),
- decltype(add_k_n_hopx2_wopx2_global_desc),
decltype(out_k_n_hop_wop_global_desc),
KPerBlock,
HoPerBlock,
@@ -270,8 +252,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
- decltype(add_k_n_hopx2_wopx2_global_desc),
- const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, true>,
@@ -286,8 +266,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
- add_k_n_hopx2_wopx2_global_desc,
- p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, true>{},
@@ -301,8 +279,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
- decltype(add_k_n_hopx2_wopx2_global_desc),
- const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, true>,
@@ -317,8 +293,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
- add_k_n_hopx2_wopx2_global_desc,
- p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, true>{},
@@ -332,8 +306,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
- decltype(add_k_n_hopx2_wopx2_global_desc),
- const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, false>,
@@ -348,8 +320,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
- add_k_n_hopx2_wopx2_global_desc,
- p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, false>{},
@@ -363,8 +333,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
- decltype(add_k_n_hopx2_wopx2_global_desc),
- const FloatC*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, false>,
@@ -379,8 +347,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
- add_k_n_hopx2_wopx2_global_desc,
- p_d_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, false>{},
@@ -394,7 +360,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
- out_n_k0_ho_wo_k1_global_desc) /
+ out_n_k0_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
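Note: the reworked out_k_n_hop_wop_global_desc above right-pads Ho and Wo so the padded output divides evenly into HoPerBlock x WoPerBlock tiles. A minimal sketch of the pad arithmetic, assuming OutRightPadH/OutRightPadW are the usual round-up remainders (integer_least_multiple is assumed from CK's math utilities; not part of this commit):

    // hypothetical derivation of the pad amounts fed to make_pad_transform
    const auto Hop          = math::integer_least_multiple(Ho, HoPerBlock); // padded height
    const auto Wop          = math::integer_least_multiple(Wo, WoPerBlock); // padded width
    const auto OutRightPadH = Hop - Ho; // zero when Ho is already a multiple of HoPerBlock
    const auto OutRightPadW = Wop - Wo;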
@@ -18,7 +18,6 @@ template <index_t BlockSize,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
- typename DGlobalDesc,
typename CGlobalDesc,
index_t KPerBlock,
index_t HoPerBlock,
@@ -48,7 +47,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
- struct GridwiseDynamicGemm_km_kn_mn_v3
+ struct GridwiseDynamicGemm_km_kn_mn_v2
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
@@ -74,8 +73,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
- const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
- const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
@@ -353,183 +350,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
#endif
// output: register to global memory
#if 0
- {
- constexpr auto HoPerThreadx2 = HoPerThread * 2;
- constexpr auto WoPerThreadx2 = WoPerThread * 2;
- const index_t hox2_block_data_on_global = ho_block_work_id * HoPerBlock * 2;
- const index_t wox2_block_data_on_global = wo_block_work_id * WoPerBlock * 2;
- const index_t hox2_thread_data_on_global =
- hox2_block_data_on_global + ho_thread_id * HoPerThreadx2;
- const index_t wox2_thread_data_on_global =
- wox2_block_data_on_global + wo_thread_id * WoPerThreadx2;
- static_assert(KPerThread % CThreadTransferDstScalarPerVector == 0, "");
- constexpr auto KPerThreadAdd = KPerThread / CThreadTransferDstScalarPerVector;
- const index_t k_block_data_on_global_add =
- k_block_work_id * KPerBlock / CThreadTransferDstScalarPerVector;
- const index_t k_thread_data_on_global_add =
- k_block_data_on_global_add + k_thread_id * KPerThreadAdd;
- constexpr auto d_k_n_hox2_wox2_thread_desc =
- make_dynamic_naive_tensor_descriptor_packed_v2(
- make_tuple(Number<1>{}, Number<1>{}, Number<1>{}, Number<1>{}));
- constexpr auto vector_len = CThreadTransferDstScalarPerVector;
- constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
- vector_type<int8_t, vector_len> d_vec;
- for(index_t k_i = 0; k_i < KPerThreadAdd; ++k_i)
- {
- for(index_t h_i = 0; h_i < HoPerThreadx2; ++h_i)
- {
- for(index_t w_i = 0; w_i < WoPerThreadx2; ++w_i)
- {
- ThreadwiseDynamicTensorSliceTransfer_v2<
- FloatC,
- decltype(d_vec),
- decltype(d_k_n_hox2_wox2_global_desc),
- decltype(d_k_n_hox2_wox2_thread_desc),
- Sequence<1, 1, 1, 1>,
- CThreadTransferSrcDstAccessOrder,
- CThreadTransferSrcDstVectorDim,
- // CThreadTransferDstScalarPerVector,
- 1,
- AddressSpace::Global,
- AddressSpace::Vgpr,
- InMemoryDataOperation::Set,
- 1,
- true>(d_k_n_hox2_wox2_global_desc,
- make_multi_index(k_thread_data_on_global_add + k_i,
- 0,
- hox2_thread_data_on_global + h_i,
- wox2_thread_data_on_global + w_i))
- .Run2(d_k_n_hox2_wox2_global_desc,
- p_d_global,
- d_k_n_hox2_wox2_thread_desc,
- make_tuple(I0, I0, I0, I0),
- d_vec,
- c_k_n_ho_wo_global_tensor_iterator_hacks);
- static_for<0, vector_len, 1>{}([&](auto i) {
- d_vec.template AsType<int8_t>()(i) +=
- p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
- make_tuple(k_i * vector_len + i, 0, h_i / 2, w_i / 2))];
- });
- ThreadwiseDynamicTensorSliceTransfer_v1r3<
- decltype(d_vec),
- FloatC,
- decltype(d_k_n_hox2_wox2_thread_desc),
- decltype(d_k_n_hox2_wox2_global_desc),
- Sequence<1, 1, 1, 1>,
- CThreadTransferSrcDstAccessOrder,
- CThreadTransferSrcDstVectorDim,
- // CThreadTransferDstScalarPerVector,
- 1,
- AddressSpace::Vgpr,
- AddressSpace::Global,
- CGlobalMemoryDataOperation,
- 1,
- true>(d_k_n_hox2_wox2_global_desc,
- make_multi_index(k_thread_data_on_global_add + k_i,
- 0,
- hox2_thread_data_on_global + h_i,
- wox2_thread_data_on_global + w_i))
- .Run2(d_k_n_hox2_wox2_thread_desc,
- make_tuple(I0, I0, I0, I0),
- d_vec,
- d_k_n_hox2_wox2_global_desc,
- p_c_global,
- c_k_n_ho_wo_global_tensor_iterator_hacks);
- }
- }
- }
- }
#else
{
- constexpr auto HoPerThreadx2 = HoPerThread * 2;
- constexpr auto WoPerThreadx2 = WoPerThread * 2;
- const index_t hox2_block_data_on_global = ho_block_work_id * HoPerBlock * 2;
- const index_t wox2_block_data_on_global = wo_block_work_id * WoPerBlock * 2;
- const index_t hox2_thread_data_on_global =
- hox2_block_data_on_global + ho_thread_id * HoPerThreadx2;
- const index_t wox2_thread_data_on_global =
- wox2_block_data_on_global + wo_thread_id * WoPerThreadx2;
+ static_assert(CThreadTransferDstScalarPerVector == 16 && KPerBlock == 16, "");
+ const index_t k_block_data_on_global_vec =
+ k_block_work_id * (KPerBlock / CThreadTransferDstScalarPerVector);
+ const index_t KPerThreadVec = KPerThread / CThreadTransferDstScalarPerVector;
+ const index_t k_thread_data_on_global_vec =
+ k_block_data_on_global_vec + k_thread_id * KPerThreadVec;
- static_assert(KPerThread % CThreadTransferDstScalarPerVector == 0, "");
- static_assert(CThreadTransferDstScalarPerVector == 16, "");
- constexpr auto KPerThreadAdd = KPerThread / CThreadTransferDstScalarPerVector;
- const index_t k_block_data_on_global_add =
- k_block_work_id * KPerBlock / CThreadTransferDstScalarPerVector;
- const index_t k_thread_data_on_global_add =
- k_block_data_on_global_add + k_thread_id * KPerThreadAdd;
- constexpr auto d_k_n_hox2_wox2_thread_desc =
- make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerThreadAdd>{},
+ constexpr auto c_k_n_ho_wo_thread_desc_vec =
+ make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerThreadVec>{},
Number<1>{},
- Number<HoPerThreadx2>{},
- Number<WoPerThreadx2>{}));
+ Number<HoPerThread>{},
+ Number<WoPerThread>{}));
- constexpr auto vec_len = d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize() *
- CThreadTransferDstScalarPerVector;
+ static_assert(c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize() == 4, "");
- static_assert(vec_len == 256, "");
- // vector_type<int8_t, vec_len> d_vec;
- FloatC d_vec[d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize()];
+ FloatC d_vec[c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize()];
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
- ThreadwiseDynamicTensorSliceTransfer_v2<
- FloatC,
- // decltype(d_vec),
- FloatC,
- decltype(d_k_n_hox2_wox2_global_desc),
- decltype(d_k_n_hox2_wox2_thread_desc),
- Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>,
- CThreadTransferSrcDstAccessOrder,
- CThreadTransferSrcDstVectorDim,
- // CThreadTransferDstScalarPerVector,
- 1,
- AddressSpace::Global,
- AddressSpace::Vgpr,
- InMemoryDataOperation::Set,
- 1,
- true>(d_k_n_hox2_wox2_global_desc,
- make_multi_index(k_thread_data_on_global_add,
- 0,
- hox2_thread_data_on_global,
- wox2_thread_data_on_global))
- .Run(d_k_n_hox2_wox2_global_desc,
- p_d_global,
- d_k_n_hox2_wox2_thread_desc,
- make_tuple(I0, I0, I0, I0),
- d_vec,
- c_k_n_ho_wo_global_tensor_iterator_hacks);
- static_for<0, KPerThreadAdd, 1>{}([&](auto k_i) {
- static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) {
- static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) {
+ static_for<0, KPerThreadVec, 1>{}([&](auto k_i) {
+ static_for<0, HoPerThread, 1>{}([&](auto h_i) {
+ static_for<0, WoPerThread, 1>{}([&](auto w_i) {
vector_type<int8_t, CThreadTransferDstScalarPerVector> t;
// t.template AsType<FloatC>()(Number<0>{}) = d_vec.template AsType<
- // FloatC>()[Number<d_k_n_hox2_wox2_thread_desc.CalculateOffset(
+ // FloatC>()[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(
// make_tuple(k_i, 0, h_i, w_i))>{}];
t.template AsType<FloatC>()(Number<0>{}) =
- d_vec[Number<d_k_n_hox2_wox2_thread_desc.CalculateOffset(
+ d_vec[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(
make_tuple(k_i, 0, h_i, w_i))>{}];
static_for<0, CThreadTransferDstScalarPerVector, 1>{}([&](auto i) {
- t.template AsType<int8_t>()(i) +=
- p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
+ t.template AsType<int8_t>()(i) =
+ p_c_thread[c_k_n_ho_wo_thread_desc_vec.CalculateOffset(
make_tuple(k_i * CThreadTransferDstScalarPerVector + i,
0,
h_i / 2,
@@ -537,22 +394,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
});
// d_vec.template AsType<FloatC>()(
- // Number<d_k_n_hox2_wox2_thread_desc.CalculateOffset(make_tuple(
+ // Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(make_tuple(
// k_i, 0, h_i, w_i))>{}) = t.template AsType<FloatC>()[Number<0>{}];
- d_vec[Number<d_k_n_hox2_wox2_thread_desc.CalculateOffset(make_tuple(
+ d_vec[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(make_tuple(
k_i, 0, h_i, w_i))>{}] = t.template AsType<FloatC>()[Number<0>{}];
});
});
});
ThreadwiseDynamicTensorSliceTransfer_v1r3<
// decltype(d_vec),
FloatC,
FloatC,
- decltype(d_k_n_hox2_wox2_thread_desc),
- decltype(d_k_n_hox2_wox2_global_desc),
- Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>,
+ decltype(c_k_n_ho_wo_thread_desc_vec),
+ decltype(c_k_n_ho_wo_global_desc),
+ Sequence<KPerThreadVec, 1, HoPerThread, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
// CThreadTransferDstScalarPerVector,
@@ -561,19 +417,18 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
- true>(d_k_n_hox2_wox2_global_desc,
- make_multi_index(k_thread_data_on_global_add,
+ true>(c_k_n_ho_wo_global_desc,
+ make_multi_index(k_thread_data_on_global_vec,
0,
- hox2_thread_data_on_global,
- wox2_thread_data_on_global))
- .Run(d_k_n_hox2_wox2_thread_desc,
+ ho_thread_data_on_global,
+ wo_thread_data_on_global))
+ .Run(c_k_n_ho_wo_thread_desc_vec,
make_tuple(I0, I0, I0, I0),
d_vec,
- d_k_n_hox2_wox2_global_desc,
+ c_k_n_ho_wo_global_desc,
p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks);
}
#endif
}
// pass tensor descriptor by reference
@@ -582,8 +437,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
- const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
- const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
@@ -597,8 +450,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
- d_k_n_hox2_wox2_global_desc,
- p_d_global,
c_k_n_ho_wo_global_desc,
p_c_global,
p_shared_block,
@@ -612,8 +463,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
- const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
- const FloatC* __restrict__ p_d_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
@@ -627,8 +476,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
- d_k_n_hox2_wox2_global_desc,
- p_d_global,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
@@ -641,8 +488,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const FloatAB* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
- const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
- const FloatC* __restrict__ p_d_global,
const void* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
@@ -658,8 +503,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
- d_k_n_hox2_wox2_global_desc,
- p_d_global,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
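Note: the rewritten output stage above is what the commit title "use vec" refers to. Instead of loading the D tensor and adding it element by element, each thread now regroups its int8 results so that CThreadTransferDstScalarPerVector (16) of them travel in one FloatC-typed value per store. A standalone sketch of that packing with illustrative names and sizes (KPerThread = 64 is an assumption, not taken from the commit):

    #include <array>
    #include <cstdint>

    struct int8x16 { int8_t lane[16]; }; // stand-in for the kernel's 16-byte FloatC

    // Regroup k-major int8 accumulators: element k lands in lane (k % 16) of
    // vector (k / 16), so one vector store covers 16 output channels at once.
    std::array<int8x16, 4> pack_accumulators(const std::array<int8_t, 64>& acc)
    {
        std::array<int8x16, 4> packed{};
        for(int k = 0; k < 64; ++k)
            packed[k / 16].lane[k % 16] = acc[k];
        return packed;
    }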
@@ -10,7 +10,6 @@ template <class TInWei,
class TOut,
class InDesc,
class WeiDesc,
- class AddDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
@@ -21,10 +20,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
const Tensor<TInWei>& in_n_c_hi_wi,
WeiDesc,
const Tensor<TInWei>& wei_k_c_y_x,
- AddDesc,
- const Tensor<TOut>& add_n_k_hox2_wox2,
OutDesc,
- Tensor<TOut>& out_n_k_hox2_wox2,
+ Tensor<TOut>& out_n_k_ho_wo,
ConvStrides,
ConvDilations,
InLeftPads,
@@ -38,10 +35,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
- DeviceMem add_n_k_hox2_wox2_device_buf(sizeof(TOut) *
- add_n_k_hox2_wox2.mDesc.GetElementSpace());
- DeviceMem out_n_k_hox2_wox2_device_buf(sizeof(TOut) *
- add_n_k_hox2_wox2.mDesc.GetElementSpace());
+ DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -58,9 +52,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr auto Ho = OutDesc::GetLengths()[I2];
constexpr auto Wo = OutDesc::GetLengths()[I3];
- constexpr auto Hox2 = Ho * 2;
- constexpr auto Wox2 = Wo * 2;
constexpr auto Y = WeiDesc::GetLengths()[I2];
constexpr auto X = WeiDesc::GetLengths()[I3];
@@ -76,7 +67,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
- const auto out_n_k_hox2_wox2_desc =
+ const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
const auto conv_strides = to_multi_index(ConvStrides{});
@@ -89,10 +80,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
const auto wei_k_c0_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
- const auto out_n_k0_ho_wo_k1_desc =
- make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
- const auto add_n_k0_hox2_wox2_desc =
- make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2));
+ const auto out_n_k0_ho_wo_desc =
+ make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
@@ -104,10 +93,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{})));
Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{})));
- Tensor<TOut> add_n_k0_hox2_wox2_k1(make_HostTensorDescriptor(
- make_native_tensor_descriptor_packed(Sequence<N, K0, Hox2, Wox2, K1>{})));
- Tensor<TOut> out_n_k0_hox2_wox2_k1(make_HostTensorDescriptor(
- make_native_tensor_descriptor_packed(Sequence<N, K0, Hox2, Wox2, K1>{})));
+ Tensor<TOut> out_n_k0_ho_wo_k1(make_HostTensorDescriptor(
+ make_native_tensor_descriptor_packed(Sequence<N, K0, Ho, Wo, K1>{})));
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
@@ -119,18 +106,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
wei_k_c_y_x(k, c, y, x);
};
- auto f_nkhw_to_nk0hwk1 = [&](auto n, auto k, auto ho, auto wo) {
- add_n_k0_hox2_wox2_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize) =
- add_n_k_hox2_wox2(n, k, ho, wo);
- };
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
- make_ParallelTensorFunctor(f_nkhw_to_nk0hwk1, N, K, Hox2, Wox2)();
in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
- add_n_k_hox2_wox2_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());
#if 1
// cdata = 64, BlockSize = 64, 16x8x32x4
@@ -210,8 +190,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
conv_driver.Run(wei_k_c0_y_x_desc,
in_n_c0_hi_wi_desc,
- add_n_k0_hox2_wox2_desc,
- out_n_k0_ho_wo_k1_desc,
+ out_n_k0_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
@@ -221,18 +200,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
- static_cast<typename vector_type<TOut, InWeiVectorSize>::type*>(
- add_n_k_hox2_wox2_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TOut, InWeiVectorSize>::type*>(
- out_n_k_hox2_wox2_device_buf.GetDeviceBuffer()));
+ out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
- out_n_k_hox2_wox2_device_buf.FromDevice(out_n_k0_hox2_wox2_k1.mData.data());
+ out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
#if 1
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
- out_n_k_hox2_wox2(n, k, ho, wo) =
- out_n_k0_hox2_wox2_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
+ out_n_k_ho_wo(n, k, ho, wo) =
+ out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
};
- make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Hox2, Wox2)();
+ make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
#endif
}
@@ -10,7 +10,6 @@ template <class TIn,
class UpperPads>
void host_direct_convolution(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx,
- const Tensor<TOut>& add_nkhw,
Tensor<TOut>& out_nkhw,
ConvStrides,
ConvDilations,
@@ -41,21 +40,14 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
}
}
}
- index_t hox2 = ho * 2;
- index_t wox2 = wo * 2;
- out_nkhw(n, k, hox2, wox2) = v + add_nkhw(n, k, hox2, wox2);
- out_nkhw(n, k, hox2, wox2 + 1) = v + add_nkhw(n, k, hox2, wox2 + 1);
- out_nkhw(n, k, hox2 + 1, wox2) = v + add_nkhw(n, k, hox2 + 1, wox2);
- out_nkhw(n, k, hox2 + 1, wox2 + 1) = v + add_nkhw(n, k, hox2 + 1, wox2 + 1);
+ out_nkhw(n, k, ho, wo) = v;
};
auto f_par = make_ParallelTensorFunctor(f,
out_nkhw.mDesc.GetLengths()[0],
out_nkhw.mDesc.GetLengths()[1],
- out_nkhw.mDesc.GetLengths()[2] / 2,
- out_nkhw.mDesc.GetLengths()[3] / 2);
+ out_nkhw.mDesc.GetLengths()[2],
+ out_nkhw.mDesc.GetLengths()[3]);
f_par(std::thread::hardware_concurrency());
}
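Note: the device wrapper still computes in the blocked N x K0 x Ho x Wo x K1 layout (K = K0 * K1 with K1 = InWeiVectorSize) and unpacks to plain NKHW on the host via f_nk0hwk1_to_nkhw above. A standalone sketch of that unpacking over hypothetical flat int8 buffers (packed layouts assumed, as in the descriptors above):

    #include <cstdint>
    #include <vector>

    // Channel k of the NKHW view comes from block k / K1, lane k % K1 of the
    // packed (N, K0, Ho, Wo, K1) device layout.
    void nk0hwk1_to_nkhw(const std::vector<int8_t>& src, std::vector<int8_t>& dst,
                         int N, int K, int Ho, int Wo, int K1)
    {
        const int K0 = K / K1;
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                for(int h = 0; h < Ho; ++h)
                    for(int w = 0; w < Wo; ++w)
                        dst[((n * K + k) * Ho + h) * Wo + w] =
                            src[(((n * K0 + k / K1) * Ho + h) * Wo + w) * K1 + k % K1];
    }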
@@ -625,12 +625,12 @@ int main(int argc, char* argv[])
constexpr auto Ho = out_nkhw_desc.GetLength(Number<2>{});
constexpr auto Wo = out_nkhw_desc.GetLength(Number<3>{});
- auto add_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<N, K, Ho * 2, Wo * 2>{});
+ // auto add_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<N, K, Ho * 2, Wo * 2>{});
ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
- ostream_tensor_descriptor(add_nkhw_desc, std::cout << "add_nkhw_desc: ");
+ // ostream_tensor_descriptor(add_nkhw_desc, std::cout << "add_nkhw_desc: ");
print_array("LeftPads", to_multi_index(LeftPads{}));
print_array("RightPads", to_multi_index(RightPads{}));
@@ -661,10 +661,9 @@ int main(int argc, char* argv[])
Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
- Tensor<out_data_t> add_nkhw(make_HostTensorDescriptor(add_nkhw_desc));
- Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(add_nkhw_desc));
- Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(add_nkhw_desc));
+ Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(out_nkhw_desc));
+ Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc));
std::size_t num_thread = std::thread::hardware_concurrency();
@@ -700,8 +699,6 @@ int main(int argc, char* argv[])
};
wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
#endif
- // add_nkhw.GenerateTensorValue(GeneratorTensor_2{-1, 1}, num_thread);
- add_nkhw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
}
#if 0
@@ -783,8 +780,6 @@ int main(int argc, char* argv[])
in_nchw,
wei_kcyx_desc,
wei_kcyx,
- add_nkhw_desc,
- add_nkhw,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
@@ -798,7 +793,6 @@ int main(int argc, char* argv[])
{
host_direct_convolution(in_nchw,
wei_kcyx,
- add_nkhw,
out_nkhw_host,
ConvStrides{},
ConvDilations{},