Commit 59bd170d authored by root
Browse files

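Rename tuning parameters: replace the GemmK/GemmM/GemmN-style names with E/K/W-style names (e.g. GemmABlockTransferThreadSliceLengths_GemmK_GemmM -> ABlockTransferThreadSliceLengths_E_K) and drop the B-transfer destination ScalarPerVector parameter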

parent 07237cab
...@@ -20,13 +20,12 @@ template <index_t BlockSize, ...@@ -20,13 +20,12 @@ template <index_t BlockSize,
index_t HoPerThread, index_t HoPerThread,
index_t WoPerThread, index_t WoPerThread,
index_t EPerThread, index_t EPerThread,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename ABlockTransferThreadSliceLengths_E_K,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, typename ABlockTransferThreadClusterLengths_E_K,
index_t GemmABlockTransferSrcScalarPerVector_GemmK, index_t ABlockTransferSrcScalarPerVector_E,
index_t GemmABlockTransferDstScalarPerVector_GemmM, index_t ABlockTransferDstScalarPerVector_K,
index_t GemmBBlockTransferSrcScalarPerVector_GemmN, index_t BThreadTransferSrcScalarPerVector_W,
index_t GemmBBlockTransferDstScalarPerVector_GemmN, index_t CThreadTransferDstScalarPerVector_W>
index_t GemmCThreadTransferDstScalarPerVector_GemmN1>
struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
{ {
template <typename... Wei, template <typename... Wei,
...@@ -181,23 +180,22 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -181,23 +180,22 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
HoPerThread, HoPerThread,
WoPerThread, WoPerThread,
EPerThread, EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, ABlockTransferThreadSliceLengths_E_K,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, ABlockTransferThreadClusterLengths_E_K,
Sequence<1, 0>, Sequence<1, 0>,
Sequence<1, 0>, Sequence<1, 0>,
0, 0,
GemmABlockTransferSrcScalarPerVector_GemmK, ABlockTransferSrcScalarPerVector_E,
GemmABlockTransferDstScalarPerVector_GemmM, ABlockTransferDstScalarPerVector_K,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 3, 1>, Sequence<0, 2, 3, 1>,
3, 3,
GemmBBlockTransferSrcScalarPerVector_GemmN, BThreadTransferSrcScalarPerVector_W,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation // MoveSrcSliceWindow() to save addr computation
Sequence<0, 2, 3, 1>, Sequence<0, 2, 3, 1>,
3, 3,
GemmCThreadTransferDstScalarPerVector_GemmN1, CThreadTransferDstScalarPerVector_W,
decltype(a_k_m_global_iterator_hacks), decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks), decltype(b_k_n_global_iterator_hacks),
decltype(c_k_n_h_w_global_tensor_iterator_hacks), decltype(c_k_n_h_w_global_tensor_iterator_hacks),
......
...@@ -26,18 +26,17 @@ template <index_t BlockSize, ...@@ -26,18 +26,17 @@ template <index_t BlockSize,
index_t HoPerThread, index_t HoPerThread,
index_t WoPerThread, index_t WoPerThread,
index_t EPerThread, index_t EPerThread,
typename ABlockTransferThreadSliceLengths_K_M, typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_K_M, typename ABlockTransferThreadClusterLengths_E_K,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M, index_t ABlockTransferDstScalarPerVector_K,
bool AThreadTransferSrcResetCoordinateAfterRun, bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun, bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
...@@ -55,7 +54,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -55,7 +54,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const auto K = 16; const auto K = 16;
constexpr auto max_lds_align = constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<K>{}); math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<K>{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
...@@ -122,7 +121,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -122,7 +121,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// lds max alignment // lds max alignment
constexpr auto max_lds_align = constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerBlock>{}); math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
...@@ -154,7 +153,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -154,7 +153,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
WoPerThread, WoPerThread,
EPerThread, EPerThread,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M>{}; ABlockTransferDstScalarPerVector_K>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
...@@ -176,8 +175,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -176,8 +175,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
Sequence<E, K>, Sequence<E, K>,
ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_K_M, ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
Float, Float,
Float, Float,
...@@ -188,7 +187,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -188,7 +187,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
1, 1,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M, ABlockTransferDstScalarPerVector_K,
AddressSpace::Global, AddressSpace::Global,
AddressSpace::Lds, AddressSpace::Lds,
1, 1,
...@@ -231,12 +230,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -231,12 +230,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{}; constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{}; constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for // hack to control index calculation when move slice window for A and B matrix for
// threadwise copy // threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack = constexpr auto a_e_k_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{}; AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{}; BGlobalMoveSliceWindowIteratorHacks{};
...@@ -248,7 +247,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -248,7 +247,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global, p_b_global,
......
...@@ -80,16 +80,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -80,16 +80,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 2; constexpr index_t EPerThread = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<9, 1>; using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; using ABlockTransferThreadClusterLengths_E_K = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; constexpr index_t CThreadTransferDstScalarPerVector_W = 1;
constexpr auto conv_driver = constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad< DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
...@@ -104,13 +103,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -104,13 +103,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
HoPerThread, HoPerThread,
WoPerThread, WoPerThread,
EPerThread, EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, ABlockTransferThreadSliceLengths_E_K,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, ABlockTransferThreadClusterLengths_E_K,
GemmABlockTransferSrcScalarPerVector_GemmK, ABlockTransferSrcScalarPerVector_E,
GemmABlockTransferDstScalarPerVector_GemmM, ABlockTransferDstScalarPerVector_K,
GemmBBlockTransferSrcScalarPerVector_GemmN, BThreadTransferSrcScalarPerVector_W,
GemmBBlockTransferDstScalarPerVector_GemmN, CThreadTransferDstScalarPerVector_W>{};
GemmCThreadTransferDstScalarPerVector_GemmN1>{};
conv_driver.Run(wei_k_c_y_x_desc, conv_driver.Run(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc, in_n_c_hi_wi_desc,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment