Commit 758f576a authored by root's avatar root
Browse files

parameters clean

parent 6e59255a
......@@ -9,22 +9,17 @@
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t BlockSize,
typename Float,
typename AccFloat,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t KPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t CYXPerBlock,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t CYXPerThread,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK,
......@@ -130,12 +125,10 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto CYX = C * Y * X;
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0))
if(!(K % KPerBlock == 0 && Ho % HPerBlock == 0 && Wo % WPerBlock == 0 &&
CYX % CYXPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
......@@ -182,16 +175,14 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_n_ho_wo_global_desc),
decltype(out_gemmm_n_ho_wo_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
KPerBlock,
HPerBlock,
WPerBlock,
CYXPerBlock,
KPerThread,
HPerThread,
WPerThread,
CYXPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
......@@ -218,11 +209,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const auto GridSize = (K / KPerBlock) * (Ho / HPerBlock) * (Wo / WPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_main_k_block_loop = (CYX + CYXPerBlock) / (2 * CYXPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
const bool has_double_tail_k_block_loop = (CYX / CYXPerBlock) % 2 == 0;
index_t nrepeat = 100;
......
......@@ -19,15 +19,13 @@ template <index_t BlockSize,
typename BGlobalDesc,
typename CGlobalDesc,
index_t KPerBlock,
index_t HWPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t CYXPerBlock,
index_t KPerThread,
index_t HWPerThread,
index_t HPerThread,
index_t WPerThread,
index_t CYXPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
......@@ -99,7 +97,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// divide block work by [M, N]
#if 1
const auto m_block_work_num = K / Number<KPerBlock>{};
const auto nhw_block_work_num = (N * H * W) / Number<HWPerBlock>{};
const auto nhw_block_work_num = (N * H * W) / (Number<HPerBlock>{} * Number<WPerBlock>{});
const index_t k_block_work_id = get_block_1d_id() / nhw_block_work_num;
const index_t nhw_block_work_id = get_block_1d_id() - k_block_work_id * nhw_block_work_num;
......@@ -120,10 +118,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
const index_t w_block_data_on_global = nhw_block_work_id * 8;
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<KPerThread>{},
Number<HWPerThread>{});
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......
......@@ -70,18 +70,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4 * 3 * 3;
constexpr index_t KPerBlock = 16;
constexpr index_t HPerBlock = 8;
constexpr index_t WPerBlock = 8;
constexpr index_t CYXPerBlock = 4 * 3 * 3;
constexpr index_t GemmMPerThread = 16;
constexpr index_t GemmNPerThread = 1;
constexpr index_t GemmKPerThread = 4;
constexpr index_t GemmMLevel0Cluster = 1;
constexpr index_t GemmNLevel0Cluster = 1;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 64;
constexpr index_t KPerThread = 16;
constexpr index_t HPerThread = 1;
constexpr index_t WPerThread = 1;
constexpr index_t CYXPerThread = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<9, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
......@@ -102,16 +99,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
BlockSize,
TDevice,
TDevice,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
KPerBlock,
HPerBlock,
WPerBlock,
CYXPerBlock,
KPerThread,
HPerThread,
WPerThread,
CYXPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
GemmABlockTransferSrcScalarPerVector_GemmK,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment