"git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "61fe873c89b934b7e49d663f521eb5869229e980"
Commit 2558d019 authored by Chao Liu

making dynamic multi-index transform support compile-time info

parent 1e55a3b1
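The change threads compile-time length information through the dynamic tensor transforms: lengths, strides, and pads become template-typed values, so the same transform works whether a length is a runtime integer or an integral constant. Below is a minimal, self-contained sketch of that pattern; the `index_t`/`Number<>` stand-ins are illustrative only, not the library's definitions.

```cpp
#include <cstdio>
#include <type_traits>

// Illustrative stand-ins (not the library's types): a runtime index type and a
// compile-time integral constant in the spirit of ck::index_t / ck::Number<>.
using index_t = int;

template <index_t N>
struct Number : std::integral_constant<index_t, N>
{
};

// A pass-through transform whose length is a template-typed value, so one struct
// covers both a runtime length (index_t) and a compile-time one (Number<N>).
template <typename LowLength = index_t>
struct PassThroughSketch
{
    LowLength up_length_;

    constexpr explicit PassThroughSketch(const LowLength& low_length) : up_length_{low_length} {}
};

int main()
{
    // Runtime-only length: the value is stored and read at run time.
    PassThroughSketch<index_t> dyn{8};
    std::printf("dynamic length = %d\n", dyn.up_length_);

    // Compile-time length: the value lives in the type; the member is an empty object.
    PassThroughSketch<Number<8>> sta{Number<8>{}};
    std::printf("static length  = %d\n", static_cast<int>(sta.up_length_));

    static_assert(decltype(sta.up_length_)::value == 8, "length is known at compile time");
    static_assert(std::is_empty<Number<8>>::value, "no runtime storage is needed for it");
}
```

When the length is a compile-time constant, the compiler can fold it into address arithmetic and loop bounds, which is the point of the diff below.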
@@ -36,14 +36,20 @@ template <index_t BlockSize,
           index_t GemmCThreadTransferDstScalarPerVector_GemmN1>
 struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
 {
-    template <typename... Wei, typename... In, typename... Out>
+    template <typename... Wei,
+              typename... In,
+              typename... Out,
+              typename ConvStrides,
+              typename ConvDilations,
+              typename InLeftPads,
+              typename InRightPads>
     __host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
                       const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
                       const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
-                      const MultiIndex<2> conv_strides,
-                      const MultiIndex<2> conv_dilations,
-                      const MultiIndex<2> in_left_pads,
-                      const MultiIndex<2> in_right_pads,
+                      const ConvStrides& conv_strides,
+                      const ConvDilations& conv_dilations,
+                      const InLeftPads& in_left_pads,
+                      const InRightPads& in_right_pads,
                       const Float* __restrict__ p_wei_global,
                       const Float* __restrict__ p_in_global,
                       Float* __restrict__ p_out_global) const
@@ -53,30 +59,30 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        const index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
-        const index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
-        const index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
-        const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
-        const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
-        const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
-        const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
-        const index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
-        const index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
-        const index_t ConvStrideH = conv_strides[I0];
-        const index_t ConvStrideW = conv_strides[I1];
-        const index_t ConvDilationH = conv_dilations[I0];
-        const index_t ConvDilationW = conv_dilations[I1];
-        const index_t InLeftPadH = in_left_pads[I0];
-        const index_t InLeftPadW = in_left_pads[I1];
-        const index_t InRightPadH = in_right_pads[I0];
-        const index_t InRightPadW = in_right_pads[I1];
+        const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
+        const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
+        const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
+        const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
+        const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
+        const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
+        const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
+        const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
+        const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
+        const auto ConvStrideH = conv_strides[I0];
+        const auto ConvStrideW = conv_strides[I1];
+        const auto ConvDilationH = conv_dilations[I0];
+        const auto ConvDilationW = conv_dilations[I1];
+        const auto InLeftPadH = in_left_pads[I0];
+        const auto InLeftPadW = in_left_pads[I1];
+        const auto InRightPadH = in_right_pads[I0];
+        const auto InRightPadW = in_right_pads[I1];

         // weight tensor
         const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
@@ -95,8 +101,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

-        const index_t Hip = in_n_c_hip_wip_global_desc.GetLength(I2);
-        const index_t Wip = in_n_c_hip_wip_global_desc.GetLength(I3);
+        const auto Hip = in_n_c_hip_wip_global_desc.GetLength(I2);
+        const auto Wip = in_n_c_hip_wip_global_desc.GetLength(I3);

         const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
             in_n_c_hip_wip_global_desc,
@@ -123,9 +129,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
             make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

-        const index_t GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
-        const index_t GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
-        const index_t GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
+        const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
+        const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
+        const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);

         if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
              GemmK % GemmKPerBlock == 0))
@@ -133,21 +139,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
             throw std::runtime_error("wrong! GEMM size no divisible");
         }

-        constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
-        constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
+        constexpr auto GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
+        constexpr auto GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;

-        const index_t GemmM0 = GemmM / GemmM1;
-        const index_t GemmN0 = GemmN / GemmN1;
+        const auto GemmM0 = GemmM / GemmM1;
+        const auto GemmN0 = GemmN / GemmN1;

-#if 0
-        const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
-            transform_dynamic_tensor_descriptor(
-                out_gemmm_gemmn_global_desc,
-                make_tuple(DynamicUnMerge<2>{make_multi_index(GemmM0, GemmM1)},
-                           DynamicUnMerge<2>{make_multi_index(GemmN0, GemmN1)}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-#else
         const auto GemmM0_GemmM1 = make_tuple(GemmM0, Number<GemmM1>{});
         const auto GemmN0_GemmN1 = make_tuple(GemmN0, Number<GemmN1>{});

@@ -159,7 +156,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
                        DynamicUnMerge<2, false, remove_cv_t<decltype(GemmN0_GemmN1)>>{GemmN0_GemmN1}),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-#endif

         // hack to control index calculation when iterating over a_k_m_global tensor
         constexpr auto a_k_m_global_iterator_hacks =
@@ -235,7 +231,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
             decltype(a_k_m_global_move_slice_window_iterator_hack),
             decltype(b_k_n_global_move_slice_window_iterator_hack)>;

-        const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
+        const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);

         const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
@@ -724,14 +720,20 @@ template <index_t BlockSize,
           index_t GemmCThreadTransferDstScalarPerVector_GemmN1>
 struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
 {
-    template <typename... Wei, typename... In, typename... Out>
+    template <typename... Wei,
+              typename... In,
+              typename... Out,
+              typename ConvStrides,
+              typename ConvDilations,
+              typename InLeftPads,
+              typename InRightPads>
     __host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
                       const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
                       const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
-                      const MultiIndex<2> conv_strides,
-                      const MultiIndex<2> conv_dilations,
-                      const MultiIndex<2> in_left_pads,
-                      const MultiIndex<2> in_right_pads,
+                      const ConvStrides& conv_strides,
+                      const ConvDilations& conv_dilations,
+                      const InLeftPads& in_left_pads,
+                      const InRightPads& in_right_pads,
                       const Float* __restrict__ p_wei_global,
                       const Float* __restrict__ p_in_global,
                       Float* __restrict__ p_out_global) const
@@ -741,30 +743,30 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        const index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
-        const index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
-        const index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
-        const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
-        const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
-        const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
-        const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
-        const index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
-        const index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
-        const index_t ConvStrideH = conv_strides[I0];
-        const index_t ConvStrideW = conv_strides[I1];
-        const index_t ConvDilationH = conv_dilations[I0];
-        const index_t ConvDilationW = conv_dilations[I1];
-        const index_t InLeftPadH = in_left_pads[I0];
-        const index_t InLeftPadW = in_left_pads[I1];
-        const index_t InRightPadH = in_right_pads[I0];
-        const index_t InRightPadW = in_right_pads[I1];
+        const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
+        const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
+        const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
+        const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
+        const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
+        const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
+        const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
+        const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
+        const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
+        const auto ConvStrideH = conv_strides[I0];
+        const auto ConvStrideW = conv_strides[I1];
+        const auto ConvDilationH = conv_dilations[I0];
+        const auto ConvDilationW = conv_dilations[I1];
+        const auto InLeftPadH = in_left_pads[I0];
+        const auto InLeftPadW = in_left_pads[I1];
+        const auto InRightPadH = in_right_pads[I0];
+        const auto InRightPadW = in_right_pads[I1];

         if(!(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0))
         {
@@ -791,8 +793,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
         // debug: don't do padding
         const auto in_n_c_hip_wip_global_desc = in_n_c_hi_wi_global_desc;

-        const index_t Hip = in_n_c_hip_wip_global_desc.GetLength(I2);
-        const index_t Wip = in_n_c_hip_wip_global_desc.GetLength(I3);
+        const auto Hip = in_n_c_hip_wip_global_desc.GetLength(I2);
+        const auto Wip = in_n_c_hip_wip_global_desc.GetLength(I3);

         const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
             in_n_c_hip_wip_global_desc,
@@ -828,9 +830,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
             make_tuple(Sequence<0>{}, Sequence<1>{}));
 #endif

-        const index_t GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
-        const index_t GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
-        const index_t GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
+        const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
+        const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
+        const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);

         if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
              GemmK % GemmKPerBlock == 0))
@@ -838,11 +840,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
             throw std::runtime_error("wrong! GEMM size no divisible");
         }

-        constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
-        constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
+        constexpr auto GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
+        constexpr auto GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;

-        const index_t GemmM0 = GemmM / GemmM1;
-        const index_t GemmN0 = GemmN / GemmN1;
+        const auto GemmM0 = GemmM / GemmM1;
+        const auto GemmN0 = GemmN / GemmN1;

         const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
             transform_dynamic_tensor_descriptor(
@@ -924,7 +926,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
             decltype(a_k_m_global_move_slice_window_iterator_hack),
             decltype(b_k_n_global_move_slice_window_iterator_hack)>;

-        const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
+        const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);

         const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
@@ -1410,14 +1412,20 @@ template <index_t BlockSize,
           index_t GemmCThreadTransferDstScalarPerVector_GemmN1>
 struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
 {
-    template <typename... Wei, typename... In, typename... Out>
+    template <typename... Wei,
+              typename... In,
+              typename... Out,
+              typename ConvStrides,
+              typename ConvDilations,
+              typename InLeftPads,
+              typename InRightPads>
     __host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
                       const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
                       const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
-                      const MultiIndex<2> conv_strides,
-                      const MultiIndex<2> conv_dilations,
-                      const MultiIndex<2> in_left_pads,
-                      const MultiIndex<2> in_right_pads,
+                      const ConvStrides& conv_strides,
+                      const ConvDilations& conv_dilations,
+                      const InLeftPads& in_left_pads,
+                      const InRightPads& in_right_pads,
                       const Float* __restrict__ p_wei_global,
                       const Float* __restrict__ p_in_global,
                       Float* __restrict__ p_out_global) const
@@ -1427,30 +1435,30 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        const index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
-        const index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
-        const index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
-        const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
-        const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
-        const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
-        const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
-        const index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
-        const index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
-        const index_t ConvStrideH = conv_strides[I0];
-        const index_t ConvStrideW = conv_strides[I1];
-        const index_t ConvDilationH = conv_dilations[I0];
-        const index_t ConvDilationW = conv_dilations[I1];
-        const index_t InLeftPadH = in_left_pads[I0];
-        const index_t InLeftPadW = in_left_pads[I1];
-        const index_t InRightPadH = in_right_pads[I0];
-        const index_t InRightPadW = in_right_pads[I1];
+        const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
+        const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
+        const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
+        const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
+        const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
+        const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
+        const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
+        const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
+        const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
+        const auto ConvStrideH = conv_strides[I0];
+        const auto ConvStrideW = conv_strides[I1];
+        const auto ConvDilationH = conv_dilations[I0];
+        const auto ConvDilationW = conv_dilations[I1];
+        const auto InLeftPadH = in_left_pads[I0];
+        const auto InLeftPadW = in_left_pads[I1];
+        const auto InRightPadH = in_right_pads[I0];
+        const auto InRightPadW = in_right_pads[I1];

         if(!(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
              ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
@@ -1480,9 +1488,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
             make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

-        const index_t GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
-        const index_t GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
-        const index_t GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
+        const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
+        const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
+        const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);

         if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
              GemmK % GemmKPerBlock == 0))
@@ -1490,11 +1498,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
             throw std::runtime_error("wrong! GEMM size no divisible");
         }

-        constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
-        constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
+        constexpr auto GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
+        constexpr auto GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;

-        const index_t GemmM0 = GemmM / GemmM1;
-        const index_t GemmN0 = GemmN / GemmN1;
+        const auto GemmM0 = GemmM / GemmM1;
+        const auto GemmN0 = GemmN / GemmN1;

         const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
             transform_dynamic_tensor_descriptor(
@@ -1574,7 +1582,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
             decltype(a_k_m_global_move_slice_window_iterator_hack),
             decltype(b_k_n_global_move_slice_window_iterator_hack)>;

-        const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
+        const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);

         const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
...
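This is also why the driver above switches from `const index_t X = desc.GetLength(...)` to `const auto X = ...`: if a descriptor or the stride/pad containers carry integral constants, `auto` keeps those values as compile-time types instead of collapsing everything to a runtime `index_t`. A self-contained sketch of that effect, using `std::tuple` and `std::integral_constant` as stand-ins for the library's `Tuple`/`Number`:

```cpp
#include <cstdio>
#include <tuple>
#include <type_traits>

using index_t = int;

template <index_t N>
using Num = std::integral_constant<index_t, N>;

// `auto` preserves whatever type the container stores: a plain index_t stays a
// runtime value, while an integral-constant element stays a compile-time value.
template <typename Strides>
void run_sketch(const Strides& conv_strides)
{
    const auto stride_h = std::get<0>(conv_strides); // index_t or Num<N>
    const auto stride_w = std::get<1>(conv_strides);

    std::printf("strides %d x %d (H known at compile time: %s)\n",
                static_cast<int>(stride_h),
                static_cast<int>(stride_w),
                std::is_empty<decltype(stride_h)>::value ? "yes" : "no");
}

int main()
{
    run_sketch(std::make_tuple(index_t{2}, index_t{2})); // fully dynamic
    run_sketch(std::make_tuple(Num<2>{}, Num<2>{}));     // fully compile-time
    run_sketch(std::make_tuple(Num<1>{}, index_t{1}));   // mixed
}
```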
@@ -6,17 +6,20 @@
 namespace ck {

+template <typename LowLength = index_t>
 struct DynamicPassThrough
 {
     using LowerIndex = MultiIndex<1>;
     using UpperIndex = MultiIndex<1>;

-    UpperIndex up_lengths_;
+    using UpLengths = decltype(make_tuple(LowLength{}));
+
+    UpLengths up_lengths_;

     __host__ __device__ constexpr DynamicPassThrough() = default;

-    __host__ __device__ constexpr DynamicPassThrough(const index_t& low_length)
-        : up_lengths_{make_multi_index(low_length)}
+    __host__ __device__ constexpr DynamicPassThrough(const LowLength& low_length)
+        : up_lengths_{make_tuple(low_length)}
     {
     }
@@ -75,27 +78,33 @@ struct DynamicPassThrough
     {
         printf("{");
         printf("DynamicPassThrough, ");
+        printf("up_lengths_");
         print_multi_index(up_lengths_);
         printf("}");
     }
 };
-template <bool SkipIsValidCheck = false>
+template <bool SkipIsValidCheck = false,
+          typename LowLength = index_t,
+          typename LeftPad = index_t,
+          typename RightPad = index_t>
 struct DynamicPad
 {
     using LowerIndex = MultiIndex<1>;
     using UpperIndex = MultiIndex<1>;

-    UpperIndex up_lengths_;
-    index_t left_pad_;
-    index_t right_pad_;
+    using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{} + RightPad{}));
+
+    UpLengths up_lengths_;
+    LeftPad left_pad_;
+    RightPad right_pad_;

     __host__ __device__ constexpr DynamicPad() = default;

-    __host__ __device__ constexpr DynamicPad(const index_t& low_length,
-                                             const index_t& left_pad,
-                                             const index_t& right_pad)
-        : up_lengths_{make_multi_index(low_length + left_pad + right_pad)},
+    __host__ __device__ constexpr DynamicPad(const LowLength& low_length,
+                                             const LeftPad& left_pad,
+                                             const RightPad& right_pad)
+        : up_lengths_{make_tuple(low_length + left_pad + right_pad)},
           left_pad_{left_pad},
           right_pad_{right_pad}
     {
@@ -158,27 +167,30 @@ struct DynamicPad
     {
         printf("{");
         printf("DynamicPad, ");
+        printf("up_lengths_");
         print_multi_index(up_lengths_);
-        printf("left_pad_ %d", left_pad_);
-        printf(", ");
-        printf("right_pad_ %d", right_pad_);
+        printf("left_pad_ %d", index_t{left_pad_});
+        printf("right_pad_ %d", index_t{right_pad_});
         printf("}");
     }
 };
-template <bool SkipIsValidCheck = false>
+template <bool SkipIsValidCheck = false, typename LowLength = index_t, typename LeftPad = index_t>
 struct DynamicLeftPad
 {
     using LowerIndex = MultiIndex<1>;
     using UpperIndex = MultiIndex<1>;

-    UpperIndex up_lengths_;
-    index_t left_pad_;
+    using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{}));
+
+    UpLengths up_lengths_;
+    LeftPad left_pad_;

     __host__ __device__ constexpr DynamicLeftPad() = default;

-    __host__ __device__ constexpr DynamicLeftPad(const index_t& low_length, const index_t& left_pad)
-        : up_lengths_{make_multi_index(low_length + left_pad)}, left_pad_{left_pad}
+    __host__ __device__ constexpr DynamicLeftPad(const LowLength& low_length,
+                                                 const LeftPad& left_pad)
+        : up_lengths_{make_tuple(low_length + left_pad)}, left_pad_{left_pad}
     {
     }
@@ -238,27 +250,30 @@ struct DynamicLeftPad
     {
         printf("{");
         printf("DynamicLeftPad, ");
+        printf("up_lengths_");
         print_multi_index(up_lengths_);
-        printf("left_pad_ %d", left_pad_);
+        printf("left_pad_ %d", index_t{left_pad_});
         printf("}");
     }
 };
-template <bool SkipIsValidCheck = false>
+template <bool SkipIsValidCheck = false, typename LowLength = index_t, typename RightPad = index_t>
 struct DynamicRightPad
 {
     using LowerIndex = MultiIndex<1>;
     using UpperIndex = MultiIndex<1>;

-    UpperIndex up_lengths_;
-    index_t low_length_;
-    index_t right_pad_;
+    using UpLengths = decltype(make_tuple(LowLength{} + RightPad{}));
+
+    UpLengths up_lengths_;
+    LowLength low_length_;
+    RightPad right_pad_;

     __host__ __device__ constexpr DynamicRightPad() = default;

-    __host__ __device__ constexpr DynamicRightPad(const index_t& low_length,
-                                                  const index_t& right_pad)
-        : up_lengths_{make_multi_index(low_length + right_pad)},
+    __host__ __device__ constexpr DynamicRightPad(const LowLength& low_length,
+                                                  const RightPad& right_pad)
+        : up_lengths_{make_tuple(low_length + right_pad)},
           low_length_{low_length},
           right_pad_{right_pad}
     {
@@ -320,8 +335,10 @@ struct DynamicRightPad
     {
         printf("{");
        printf("DynamicRightPad, ");
+        printf("up_lengths_");
         print_multi_index(up_lengths_);
-        printf("left_pad_ %d", right_pad_);
+        printf("low_length_ %d", index_t{low_length_});
+        printf("left_pad_ %d", index_t{right_pad_});
         printf("}");
     }
 };
@@ -422,24 +439,29 @@ struct DynamicEmbed
     }
 };

-template <index_t NDimLow>
+template <index_t NDimLow, typename LowLengths = MultiIndex<NDimLow>>
 struct DynamicMerge
 {
     using LowerIndex = MultiIndex<NDimLow>;
     using UpperIndex = MultiIndex<1>;

-    LowerIndex low_lengths_;
-    LowerIndex low_lengths_scan_;
-    UpperIndex up_lengths_;
+    using LowLengthsScan = decltype(
+        container_reverse_exclusive_scan(LowLengths{}, math::multiplies_v2{}, Number<1>{}));
+
+    using UpLengths =
+        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{})));
+
+    LowLengths low_lengths_;
+    LowLengthsScan low_lengths_scan_;
+    UpLengths up_lengths_;

     __host__ __device__ constexpr DynamicMerge() = default;

-    __host__ __device__ constexpr DynamicMerge(const LowerIndex& low_lengths)
+    __host__ __device__ constexpr DynamicMerge(const LowLengths& low_lengths)
         : low_lengths_{low_lengths},
-          low_lengths_scan_{container_reverse_exclusive_scan(
-              low_lengths, math::multiplies<index_t>{}, index_t{1})},
-          up_lengths_{make_multi_index(
-              container_reduce(low_lengths, math::multiplies<index_t>(), index_t{1}))}
+          low_lengths_scan_{
+              container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
+          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))}
     {
         static_assert(LowerIndex::Size() == NDimLow, "wrong!");
     }
@@ -1017,31 +1039,27 @@ struct DynamicUnMerge
     {
         printf("{");
         printf("DynamicUnMerge, ");
+        printf("up_lengths_");
         print_multi_index(up_lengths_);
         print_multi_index(up_lengths_scan_);
         printf("}");
     }
 };
+template <typename LowerIndex = index_t>
 struct DynamicFreeze
 {
-    using LowerIndex = MultiIndex<1>;
-    using UpperIndex = MultiIndex<0>;
-
     LowerIndex low_idx_;

     __host__ __device__ constexpr DynamicFreeze() = default;

-    __host__ __device__ constexpr DynamicFreeze(const index_t& low_idx)
-        : low_idx_{make_multi_index(low_idx)}
-    {
-    }
+    __host__ __device__ constexpr DynamicFreeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}

     __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

     __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; }

-    __host__ __device__ static constexpr auto GetUpperLengths() { return UpperIndex{}; }
+    __host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; }

     template <typename LowIdx, typename UpIdx>
     __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
@@ -1081,7 +1099,11 @@ struct DynamicFreeze
         return true;
     }

-    __host__ __device__ void Print() const { printf("DynamicFreeze"); }
+    __host__ __device__ void Print() const
+    {
+        printf("DynamicFreeze");
+        printf("low_idx_ %d", index_t{low_idx_});
+    }
 };

 } // namespace ck
...
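The pad transforms above now derive their upper length type from the operand types (`decltype(make_tuple(LowLength{} + LeftPad{} + RightPad{}))`), so when every input is an integral constant the padded length is itself a compile-time constant. A standalone sketch of that layout; the names and the `operator+` overload here are illustrative stand-ins, not the library's code:

```cpp
#include <cstdio>
#include <type_traits>

using index_t = int;

template <index_t N>
struct Number : std::integral_constant<index_t, N>
{
};

// Compile-time addition: Number + Number stays a Number; plain index_t stays runtime.
template <index_t X, index_t Y>
constexpr Number<X + Y> operator+(Number<X>, Number<Y>) { return {}; }

// The upper length's type is computed from the operand types, so it is a Number<>
// exactly when all inputs are Numbers, and an index_t otherwise.
template <typename LowLength = index_t, typename LeftPad = index_t, typename RightPad = index_t>
struct PadSketch
{
    using UpLength = decltype(LowLength{} + LeftPad{} + RightPad{});

    UpLength up_length_;
    LeftPad left_pad_;
    RightPad right_pad_;

    constexpr PadSketch(const LowLength& low, const LeftPad& l, const RightPad& r)
        : up_length_{low + l + r}, left_pad_{l}, right_pad_{r}
    {
    }
};

int main()
{
    // Fully dynamic: everything is a runtime index_t.
    PadSketch<> dyn{30, 1, 1};
    std::printf("dynamic padded length = %d\n", dyn.up_length_);

    // Fully static: the padded length 30 + 1 + 1 is encoded in the type.
    PadSketch<Number<30>, Number<1>, Number<1>> sta{Number<30>{}, Number<1>{}, Number<1>{}};
    std::printf("static padded length  = %d\n", static_cast<int>(sta.up_length_));
    static_assert(decltype(sta)::UpLength::value == 32, "padded length folds at compile time");
}
```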
@@ -118,6 +118,7 @@ enum InMemoryDataOperation
     AtomicAdd
 };

+// index type
 using index_t = int32_t;

 typedef int32_t int32x2_t __attribute__((ext_vector_type(2)));
...
@@ -3,6 +3,19 @@
 #include "host_tensor.hpp"
 #include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"

+template <typename T>
+__host__ __device__ constexpr auto sequence_to_tuple_of_number(const T& x)
+{
+    using namespace ck;
+
+    return generate_tuple(
+        [&](auto i) {
+            constexpr index_t tmp = T::At(i);
+            return Number<tmp>{};
+        },
+        T::Size());
+}
+
 template <class T,
           class InDesc,
           class WeiDesc,
@@ -27,11 +40,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
     using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;

-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-    constexpr auto I2 = Number<2>{};
-    constexpr auto I3 = Number<3>{};
-
     std::size_t data_sz = sizeof(T);
     DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
     DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
@@ -41,7 +49,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
     wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

-    // assume packed tensor
+#if 1
     const auto in_n_c_hi_wi_desc =
         make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
     const auto wei_k_c_y_x_desc =
@@ -53,6 +61,19 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
     const auto conv_dilations = to_multi_index(ConvDilations{});
     const auto in_left_pads = to_multi_index(InLeftPads{});
     const auto in_right_pads = to_multi_index(InRightPads{});
+#else
+    const auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+        sequence_to_tuple_of_number(InDesc::GetLengths()));
+    const auto wei_k_c_y_x_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+        sequence_to_tuple_of_number(WeiDesc::GetLengths()));
+    const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+        sequence_to_tuple_of_number(OutDesc::GetLengths()));
+
+    const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
+    const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
+    const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
+    const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
+#endif
 #if 0
     // cdata = 64, BlockSize = 256, 128x128x2
@@ -210,28 +231,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
     constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
 #endif

-    const index_t N = out_n_k_ho_wo_desc.GetLength(I0);
-    const index_t K = out_n_k_ho_wo_desc.GetLength(I1);
-    const index_t Ho = out_n_k_ho_wo_desc.GetLength(I2);
-    const index_t Wo = out_n_k_ho_wo_desc.GetLength(I3);
-
-    const index_t C = wei_k_c_y_x_desc.GetLength(I1);
-    const index_t Y = wei_k_c_y_x_desc.GetLength(I2);
-    const index_t X = wei_k_c_y_x_desc.GetLength(I3);
-
-    const index_t GemmM = K;
-    const index_t GemmN = N * Ho * Wo;
-    const index_t GemmK = C * Y * X;
-
-    if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0))
-    {
-        throw std::runtime_error("wrong! GEMM size no divisible");
-    }
-
-    const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
-
-    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-
     constexpr auto conv_driver =
 #if 1
         DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
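The helper added at the top of this file, `sequence_to_tuple_of_number`, converts a compile-time `Sequence` of lengths into a tuple of `Number<>`s, so descriptors, strides, and pads can be built entirely from compile-time values (the `#if 1/#else` switch selects between the runtime `to_multi_index` path and this compile-time path). Below is a standalone analogue using `std::integer_sequence` and `std::integral_constant` in place of the library's `Sequence`/`Number`; the example lengths are made up for illustration.

```cpp
#include <cstdio>
#include <tuple>
#include <type_traits>
#include <utility>

using index_t = int;

template <index_t N>
using Number = std::integral_constant<index_t, N>;

// Analogue of sequence_to_tuple_of_number: turn a compile-time list of lengths into a
// tuple whose elements are integral constants, so each length is carried in the type
// system rather than as a runtime value.
template <index_t... Ns>
constexpr auto sequence_to_tuple_of_number(std::integer_sequence<index_t, Ns...>)
{
    return std::make_tuple(Number<Ns>{}...);
}

int main()
{
    // Hypothetical N, C, H, W lengths, all known at compile time.
    constexpr auto lengths =
        sequence_to_tuple_of_number(std::integer_sequence<index_t, 128, 192, 71, 71>{});

    static_assert(std::tuple_element_t<1, decltype(lengths)>::value == 192,
                  "each length is a compile-time constant");

    std::printf("W = %d\n", static_cast<int>(std::get<3>(lengths)));
}
```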