Commit 59bd170d authored by root's avatar root
Browse files

rename tuning param

parent 07237cab
......@@ -20,13 +20,12 @@ template <index_t BlockSize,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK,
index_t GemmABlockTransferDstScalarPerVector_GemmM,
index_t GemmBBlockTransferSrcScalarPerVector_GemmN,
index_t GemmBBlockTransferDstScalarPerVector_GemmN,
index_t GemmCThreadTransferDstScalarPerVector_GemmN1>
typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K,
index_t ABlockTransferSrcScalarPerVector_E,
index_t ABlockTransferDstScalarPerVector_K,
index_t BThreadTransferSrcScalarPerVector_W,
index_t CThreadTransferDstScalarPerVector_W>
struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
{
template <typename... Wei,
......@@ -181,23 +180,22 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
HoPerThread,
WoPerThread,
EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K,
false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 3, 1>,
3,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
BThreadTransferSrcScalarPerVector_W,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<0, 2, 3, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
CThreadTransferDstScalarPerVector_W,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_k_n_h_w_global_tensor_iterator_hacks),
......
......@@ -26,18 +26,17 @@ template <index_t BlockSize,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
index_t ABlockTransferDstScalarPerVector_K,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
......@@ -55,7 +54,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const auto K = 16;
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<K>{});
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<K>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -122,7 +121,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// lds max alignment
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerBlock>{});
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -154,7 +153,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
WoPerThread,
EPerThread,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M>{};
ABlockTransferDstScalarPerVector_K>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
......@@ -176,8 +175,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<E, K>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder,
Float,
Float,
......@@ -188,7 +187,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
ABlockTransferDstScalarPerVector_K,
AddressSpace::Global,
AddressSpace::Lds,
1,
......@@ -231,12 +230,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack =
constexpr auto a_e_k_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
......@@ -248,7 +247,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
......
......@@ -80,16 +80,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<9, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
constexpr index_t CThreadTransferDstScalarPerVector_W = 1;
constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
......@@ -104,13 +103,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
HoPerThread,
WoPerThread,
EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
GemmCThreadTransferDstScalarPerVector_GemmN1>{};
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K,
BThreadTransferSrcScalarPerVector_W,
CThreadTransferDstScalarPerVector_W>{};
conv_driver.Run(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment