Commit 758f576a authored by root's avatar root
Browse files

parameters clean

parent 6e59255a
...@@ -9,22 +9,17 @@ ...@@ -9,22 +9,17 @@
namespace ck { namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t BlockSize, template <index_t BlockSize,
typename Float, typename Float,
typename AccFloat, typename AccFloat,
index_t GemmMPerBlock, index_t KPerBlock,
index_t GemmNPerBlock, index_t HPerBlock,
index_t GemmKPerBlock, index_t WPerBlock,
index_t GemmMPerThread, index_t CYXPerBlock,
index_t GemmNPerThread, index_t KPerThread,
index_t GemmKPerThread, index_t HPerThread,
index_t GemmMLevel0Cluster, index_t WPerThread,
index_t GemmNLevel0Cluster, index_t CYXPerThread,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK, index_t GemmABlockTransferSrcScalarPerVector_GemmK,
...@@ -130,12 +125,10 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -130,12 +125,10 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto GemmM = K; const auto CYX = C * Y * X;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && if(!(K % KPerBlock == 0 && Ho % HPerBlock == 0 && Wo % WPerBlock == 0 &&
GemmK % GemmKPerBlock == 0)) CYX % CYXPerBlock == 0))
{ {
throw std::runtime_error("wrong! GEMM size no divisible"); throw std::runtime_error("wrong! GEMM size no divisible");
} }
...@@ -182,16 +175,14 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -182,16 +175,14 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_n_ho_wo_global_desc), decltype(in_gemmk_n_ho_wo_global_desc),
decltype(out_gemmm_n_ho_wo_global_desc), decltype(out_gemmm_n_ho_wo_global_desc),
GemmMPerBlock, KPerBlock,
GemmNPerBlock, HPerBlock,
GemmKPerBlock, WPerBlock,
GemmMPerThread, CYXPerBlock,
GemmNPerThread, KPerThread,
GemmKPerThread, HPerThread,
GemmMLevel0Cluster, WPerThread,
GemmNLevel0Cluster, CYXPerThread,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>, Sequence<1, 0>,
...@@ -218,11 +209,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -218,11 +209,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(a_k_m_global_move_slice_window_iterator_hack), decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>; decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); const auto GridSize = (K / KPerBlock) * (Ho / HPerBlock) * (Wo / WPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1; const bool has_main_k_block_loop = (CYX + CYXPerBlock) / (2 * CYXPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = (CYX / CYXPerBlock) % 2 == 0;
index_t nrepeat = 100; index_t nrepeat = 100;
......
...@@ -19,15 +19,13 @@ template <index_t BlockSize, ...@@ -19,15 +19,13 @@ template <index_t BlockSize,
typename BGlobalDesc, typename BGlobalDesc,
typename CGlobalDesc, typename CGlobalDesc,
index_t KPerBlock, index_t KPerBlock,
index_t HWPerBlock, index_t HPerBlock,
index_t WPerBlock,
index_t CYXPerBlock, index_t CYXPerBlock,
index_t KPerThread, index_t KPerThread,
index_t HWPerThread, index_t HPerThread,
index_t WPerThread,
index_t CYXPerThread, index_t CYXPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M, typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M, typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
...@@ -99,7 +97,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -99,7 +97,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// divide block work by [M, N] // divide block work by [M, N]
#if 1 #if 1
const auto m_block_work_num = K / Number<KPerBlock>{}; const auto m_block_work_num = K / Number<KPerBlock>{};
const auto nhw_block_work_num = (N * H * W) / Number<HWPerBlock>{}; const auto nhw_block_work_num = (N * H * W) / (Number<HPerBlock>{} * Number<WPerBlock>{});
const index_t k_block_work_id = get_block_1d_id() / nhw_block_work_num; const index_t k_block_work_id = get_block_1d_id() / nhw_block_work_num;
const index_t nhw_block_work_id = get_block_1d_id() - k_block_work_id * nhw_block_work_num; const index_t nhw_block_work_id = get_block_1d_id() - k_block_work_id * nhw_block_work_num;
...@@ -120,10 +118,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -120,10 +118,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
const index_t w_block_data_on_global = nhw_block_work_id * 8; const index_t w_block_data_on_global = nhw_block_work_id * 8;
// lds max alignment // lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, constexpr auto max_lds_align =
Number<BBlockTransferDstScalarPerVector_N>{}, math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerThread>{});
Number<KPerThread>{},
Number<HWPerThread>{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
......
...@@ -70,18 +70,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -70,18 +70,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t GemmNPerBlock = 64; constexpr index_t HPerBlock = 8;
constexpr index_t GemmKPerBlock = 4 * 3 * 3; constexpr index_t WPerBlock = 8;
constexpr index_t CYXPerBlock = 4 * 3 * 3;
constexpr index_t GemmMPerThread = 16; constexpr index_t KPerThread = 16;
constexpr index_t GemmNPerThread = 1; constexpr index_t HPerThread = 1;
constexpr index_t GemmKPerThread = 4; constexpr index_t WPerThread = 1;
constexpr index_t CYXPerThread = 4;
constexpr index_t GemmMLevel0Cluster = 1;
constexpr index_t GemmNLevel0Cluster = 1;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 64;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<9, 1>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<9, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
...@@ -102,16 +99,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -102,16 +99,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
BlockSize, BlockSize,
TDevice, TDevice,
TDevice, TDevice,
GemmMPerBlock, KPerBlock,
GemmNPerBlock, HPerBlock,
GemmKPerBlock, WPerBlock,
GemmMPerThread, CYXPerBlock,
GemmNPerThread, KPerThread,
GemmKPerThread, HPerThread,
GemmMLevel0Cluster, WPerThread,
GemmNLevel0Cluster, CYXPerThread,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
GemmABlockTransferSrcScalarPerVector_GemmK, GemmABlockTransferSrcScalarPerVector_GemmK,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment