Commit 666bdad1 authored by root

clean

parent f744524e
@@ -13,13 +13,13 @@ template <index_t BlockSize,
           typename Float,
           typename AccFloat,
           index_t KPerBlock,
-          index_t HPerBlock,
-          index_t WPerBlock,
-          index_t CYXPerBlock,
+          index_t HoPerBlock,
+          index_t WoPerBlock,
+          index_t EPerBlock,
           index_t KPerThread,
           index_t HPerThread,
           index_t WPerThread,
-          index_t CYXPerThread,
+          index_t EPerThread,
           typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
           typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
           index_t GemmABlockTransferSrcScalarPerVector_GemmK,
@@ -123,10 +123,10 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
             make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-        const auto CYX = C * Y * X;
+        const auto E = C * Y * X;
-        if(!(K % KPerBlock == 0 && Ho % HPerBlock == 0 && Wo % WPerBlock == 0 &&
-             CYX % CYXPerBlock == 0))
+        if(!(K % KPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0 &&
+             E % EPerBlock == 0))
         {
             throw std::runtime_error("wrong! GEMM size no divisible");
         }
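Reviewer note: the renamed E dimension is the implicit-GEMM reduction length folded from the weight tensor, E = C * Y * X, and the check above only requires that K, Ho, Wo and E divide evenly by their per-block tile sizes. A minimal host-side sketch of that dimension math, with illustrative sizes that are not taken from the commit:

```cpp
#include <cstdio>

// Illustrative only: fold a KCYX filter into the GEMM dimensions used above
// (GemmM = K, GemmK = E = C*Y*X) and repeat the divisibility check.
int main()
{
    const int K = 64, C = 4, Y = 3, X = 3; // assumed example sizes
    const int Ho = 32, Wo = 32;
    const int KPerBlock = 16, HoPerBlock = 16, WoPerBlock = 16, EPerBlock = 4;

    const int E = C * Y * X; // 36: reduction length of the implicit GEMM

    const bool ok = (K % KPerBlock == 0) && (Ho % HoPerBlock == 0) &&
                    (Wo % WoPerBlock == 0) && (E % EPerBlock == 0);

    std::printf("E = %d, tile sizes %s\n", E, ok ? "divide evenly" : "do not divide");
}
```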
@@ -165,7 +165,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
 #if 1
         // GEMM
-        using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
+        using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v2<
             BlockSize,
             Float,
             AccFloat,
@@ -174,13 +174,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
             decltype(in_gemmk_n_ho_wo_global_desc),
             decltype(out_gemmm_n_ho_wo_global_desc),
             KPerBlock,
-            HPerBlock,
-            WPerBlock,
-            CYXPerBlock,
+            HoPerBlock,
+            WoPerBlock,
+            EPerBlock,
             KPerThread,
             HPerThread,
             WPerThread,
-            CYXPerThread,
+            EPerThread,
             GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
             GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
             Sequence<1, 0>,
@@ -205,11 +205,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
             decltype(a_k_m_global_move_slice_window_iterator_hack),
             decltype(b_k_n_global_move_slice_window_iterator_hack)>;
-        const auto GridSize = (K / KPerBlock) * (Ho / HPerBlock) * (Wo / WPerBlock) * N;
+        const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
-        const bool has_main_k_block_loop = (CYX + CYXPerBlock) / (2 * CYXPerBlock) > 1;
-        const bool has_double_tail_k_block_loop = (CYX / CYXPerBlock) % 2 == 0;
+        const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
+        const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
         index_t nrepeat = 100;
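Reviewer note: with the renamed parameters, the grid launches one block per (KPerBlock, HoPerBlock, WoPerBlock) output tile for every image in the batch, and the two booleans classify how many EPerBlock steps are left for the double-buffered loop. A small sketch of that trip-count arithmetic, assuming E is a multiple of EPerBlock as enforced above and using made-up problem sizes:

```cpp
#include <cassert>
#include <cstdio>

// Illustrative trip-count logic mirroring the expressions above.
int main()
{
    const int K = 64, Ho = 32, Wo = 32, N = 4;       // assumed example sizes
    const int KPerBlock = 16, HoPerBlock = 16, WoPerBlock = 16;
    const int E = 36, EPerBlock = 4;                 // assumed, E % EPerBlock == 0
    assert(E % EPerBlock == 0);

    const int GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
    const int num_e_iterations = E / EPerBlock;      // block-sized steps over E

    // The main loop consumes E in pairs (even + odd iteration); the tail code
    // handles the last one or two iterations outside the do/while.
    const bool has_main_k_block_loop        = (E + EPerBlock) / (2 * EPerBlock) > 1;
    const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;

    std::printf("grid %d, %d E-iterations: main loop %d, double tail %d\n",
                GridSize, num_e_iterations,
                has_main_k_block_loop, has_double_tail_k_block_loop);
}
```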
@@ -225,6 +225,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
         for(index_t j = 0; j < nrepeat; ++j)
         {
+#if 0
             {
                 const auto kernel =
                     run_gridwise_operation<gridwise_gemm,
@@ -251,7 +252,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                         integral_constant<bool, true>{},
                         integral_constant<bool, false>{});
             }
-#if 0
+#else
             if(has_main_k_block_loop && has_double_tail_k_block_loop)
             {
                 const auto kernel =
...
@@ -21,11 +21,11 @@ template <index_t BlockSize,
           index_t KPerBlock,
           index_t HPerBlock,
           index_t WPerBlock,
-          index_t CYXPerBlock,
+          index_t EPerBlock,
           index_t KPerThread,
           index_t HPerThread,
           index_t WPerThread,
-          index_t CYXPerThread,
+          index_t EPerThread,
           typename ABlockTransferThreadSliceLengths_K_M,
           typename ABlockTransferThreadClusterLengths_K_M,
           typename ABlockTransferThreadClusterArrangeOrder,
@@ -57,20 +57,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_cyx_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<CYXPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
+        constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
-            math::integer_least_multiple(a_cyx_k_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(a_e_k_block_desc.GetElementSpaceSize(), max_lds_align);
         return 2 * (a_block_space_size) * sizeof(Float);
     }
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_cyx_k_global_desc,
+    __device__ void Run(const AGlobalDesc& a_e_k_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc& b_cyx_n_h_w_global_desc,
+                        const BGlobalDesc& b_e_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
                         const CGlobalDesc& c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
@@ -83,15 +83,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};
-        const auto CYX = a_cyx_k_global_desc.GetLength(I0);
-        const auto K = a_cyx_k_global_desc.GetLength(I1);
-        const auto N = b_cyx_n_h_w_global_desc.GetLength(I1);
-        const auto H = b_cyx_n_h_w_global_desc.GetLength(I2);
-        const auto W = b_cyx_n_h_w_global_desc.GetLength(I3);
+        const auto E = a_e_k_global_desc.GetLength(I0);
+        const auto K = a_e_k_global_desc.GetLength(I1);
+        const auto N = b_e_n_h_w_global_desc.GetLength(I1);
+        const auto H = b_e_n_h_w_global_desc.GetLength(I2);
+        const auto W = b_e_n_h_w_global_desc.GetLength(I3);
         // divide block work by [M, N]
-#if 1
+#if 0
         const auto k_block_work_num = K / Number<KPerBlock>{};
         const auto h_block_work_num = H / Number<HPerBlock>{};
         const auto w_block_work_num = W / Number<WPerBlock>{};
@@ -99,6 +99,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         const index_t k_block_work_id  = get_block_1d_id() / hw_block_work_num;
         const index_t hw_block_work_id = get_block_1d_id() - k_block_work_id * hw_block_work_num;
+        const index_t h_block_work_id = hw_block_work_id / w_block_work_num;
+        const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;
 #else
         // Hack: this force result into SGPR
@@ -110,10 +112,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         const index_t k_block_work_id =
             __builtin_amdgcn_readfirstlane(get_block_1d_id() / hw_block_work_num);
         const index_t hw_block_work_id = get_block_1d_id() - k_block_work_id * hw_block_work_num;
-#endif
-        const index_t h_block_work_id = hw_block_work_id / w_block_work_num;
+        const index_t h_block_work_id =
+            __builtin_amdgcn_readfirstlane(hw_block_work_id / w_block_work_num);
         const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;
+#endif
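Reviewer note: both branches above decompose the 1-D block id into (k, h, w) block-work coordinates; the #else branch wraps the divisions in __builtin_amdgcn_readfirstlane so the results live in scalar registers, and this commit extends that treatment to h_block_work_id. A plain C++ sketch of the same decomposition, with the AMD-GPU-specific intrinsic elided and made-up work counts:

```cpp
#include <cstdio>

// Illustrative only: split a linear block id into (k, h, w) block-work ids,
// matching the integer arithmetic in both branches above.
int main()
{
    const int k_block_work_num = 4, h_block_work_num = 2, w_block_work_num = 2; // assumed
    const int hw_block_work_num = h_block_work_num * w_block_work_num;

    for(int block_id = 0; block_id < k_block_work_num * hw_block_work_num; ++block_id)
    {
        const int k_block_work_id  = block_id / hw_block_work_num;
        const int hw_block_work_id = block_id - k_block_work_id * hw_block_work_num;
        const int h_block_work_id  = hw_block_work_id / w_block_work_num;
        const int w_block_work_id  = hw_block_work_id - h_block_work_id * w_block_work_num;

        std::printf("block %d -> (k=%d, h=%d, w=%d)\n",
                    block_id, k_block_work_id, h_block_work_id, w_block_work_id);
    }
}
```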
         // lds max alignment
         constexpr auto max_lds_align =
@@ -121,14 +123,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_cyx_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<CYXPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
+        constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_cyx_n_h_w_block_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                Number<CYXPerBlock>{}, Number<1>{}, Number<HPerBlock>{}, Number<WPerBlock>{}));
+        constexpr auto b_e_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<EPerBlock>{}, Number<1>{}, Number<HPerBlock>{}, Number<WPerBlock>{}));
         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
@@ -136,15 +137,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
                 Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-        const auto blockwise_gemm =
-            BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
-                                            decltype(a_cyx_k_block_desc),
-                                            decltype(b_cyx_n_h_w_block_desc),
+        const auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
+                                                                    decltype(a_e_k_block_desc),
+                                                                    decltype(b_e_n_h_w_block_desc),
                                             decltype(c_k_n_h_w_thread_desc),
                                             KPerThread,   // KPerThreadSubC
                                             HPerThread,   // HPerThreadSubC
                                             WPerThread,   // WPerThreadSubC
-                                            CYXPerThread, // CYXPerThreadLoop
+                                            EPerThread,   // EPerThreadLoop
                                             1,            // ThreadGemmADataPerRead_K
                                             1             // ThreadGemmBDataPerRead_W
                                             >{};
@@ -166,14 +166,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         auto a_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
-                                                   Sequence<CYXPerBlock, KPerBlock>,
+                                                   Sequence<EPerBlock, KPerBlock>,
                                                    ABlockTransferThreadSliceLengths_K_M,
                                                    ABlockTransferThreadClusterLengths_K_M,
                                                    ABlockTransferThreadClusterArrangeOrder,
                                                    Float,
                                                    Float,
-                                                   decltype(a_cyx_k_global_desc),
-                                                   decltype(a_cyx_k_block_desc),
+                                                   decltype(a_e_k_global_desc),
+                                                   decltype(a_e_k_block_desc),
                                                    ABlockTransferSrcAccessOrder,
                                                    Sequence<0, 1>,
                                                    ABlockTransferSrcVectorDim,
@@ -186,21 +186,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                                                    1,
                                                    AThreadTransferSrcResetCoordinateAfterRun,
                                                    true>(
-            a_cyx_k_global_desc,
+            a_e_k_global_desc,
             make_multi_index(0, k_block_data_on_global),
-            a_cyx_k_block_desc,
+            a_e_k_block_desc,
             make_multi_index(0, 0));
-        constexpr auto b_cyx_n_h_w_thread_desc =
+        constexpr auto b_e_n_h_w_thread_desc =
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                Number<CYXPerBlock>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
+                Number<EPerBlock>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
         auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
             Float,
             Float,
-            decltype(b_cyx_n_h_w_global_desc),
-            decltype(b_cyx_n_h_w_thread_desc),
-            Sequence<CYXPerBlock, 1, HPerThread, WPerThread>,
+            decltype(b_e_n_h_w_global_desc),
+            decltype(b_e_n_h_w_thread_desc),
+            Sequence<EPerBlock, 1, HPerThread, WPerThread>,
             Sequence<3, 2, 0, 1>, // BBlockTransferSrcAccessOrder,
             3,                    // BBlockTransferSrcVectorDim,
             1,                    // BBlockTransferSrcScalarPerVector,
@@ -208,12 +208,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             AddressSpace::Vgpr,
             InMemoryDataOperation::Set,
             1,
-            true>(b_cyx_n_h_w_global_desc,
+            true>(b_e_n_h_w_global_desc,
                   make_multi_index(0, 0, h_thread_data_on_global, w_thread_data_on_global));
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
-            math::integer_least_multiple(a_cyx_k_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(a_e_k_block_desc.GetElementSpaceSize(), max_lds_align);
         Float* p_a_block_double = p_shared_block;
@@ -223,37 +223,37 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // zero out threadwise output
         threadwise_matrix_set_zero_v3(c_k_n_h_w_thread_desc, p_c_thread);
-        constexpr auto a_block_slice_copy_step  = make_multi_index(CYXPerBlock, 0);
-        constexpr auto b_thread_slice_copy_step = make_multi_index(CYXPerBlock, 0, 0, 0);
+        constexpr auto a_block_slice_copy_step  = make_multi_index(EPerBlock, 0);
+        constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
-        constexpr auto b_cyx_n_h_w_global_iterator_hacks = BGlobalIteratorHacks{};
+        constexpr auto b_e_n_h_w_global_iterator_hacks = BGlobalIteratorHacks{};
         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
         constexpr auto a_k_m_global_move_slice_window_iterator_hack =
             AGlobalMoveSliceWindowIteratorHacks{};
-        constexpr auto b_cyx_n_h_w_global_move_slice_window_iterator_hack =
+        constexpr auto b_e_n_h_w_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};
-        constexpr auto b_thread_space_size = b_cyx_n_h_w_thread_desc.GetElementSpaceSize();
+        constexpr auto b_thread_space_size = b_e_n_h_w_thread_desc.GetElementSpaceSize();
         Float p_b_thread[b_thread_space_size * 2];
         Float* p_b_thread_double = p_b_thread;
         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+            a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+            b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
                                       p_b_global,
-                                      b_cyx_n_h_w_thread_desc,
+                                      b_e_n_h_w_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       p_b_thread_double,
-                                      b_cyx_n_h_w_global_iterator_hacks);
-            a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double);
+                                      b_e_n_h_w_global_iterator_hacks);
+            a_blockwise_copy.RunWrite(a_e_k_block_desc, p_a_block_double);
         }
 #if 1
@@ -272,89 +272,89 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             do
             {
                 // even iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc,
+                a_blockwise_copy.MoveSrcSliceWindow(a_e_k_global_desc,
                                                     a_block_slice_copy_step,
                                                     a_k_m_global_move_slice_window_iterator_hack);
-                b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
+                b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
                                                          b_thread_slice_copy_step);
                 __syncthreads();
                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+                    a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-                b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+                b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
                                           p_b_global,
-                                          b_cyx_n_h_w_thread_desc,
+                                          b_e_n_h_w_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           p_b_thread_odd,
-                                          b_cyx_n_h_w_global_iterator_hacks);
+                                          b_e_n_h_w_global_iterator_hacks);
                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(p_a_block_even, p_b_thread_even, p_c_thread);
                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_odd);
+                a_blockwise_copy.RunWrite(a_e_k_block_desc, p_a_block_odd);
                 // odd iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc,
+                a_blockwise_copy.MoveSrcSliceWindow(a_e_k_global_desc,
                                                     a_block_slice_copy_step,
                                                     a_k_m_global_move_slice_window_iterator_hack);
-                b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
+                b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
                                                          b_thread_slice_copy_step);
                 __syncthreads();
                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+                    a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-                b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+                b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
                                           p_b_global,
-                                          b_cyx_n_h_w_thread_desc,
+                                          b_e_n_h_w_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           p_b_thread_even,
-                                          b_cyx_n_h_w_global_iterator_hacks);
+                                          b_e_n_h_w_global_iterator_hacks);
                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(p_a_block_odd, p_b_thread_odd, p_c_thread);
                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_even);
+                a_blockwise_copy.RunWrite(a_e_k_block_desc, p_a_block_even);
-                b_block_data_begin += 2 * CYXPerBlock;
-            } while(b_block_data_begin < CYX - 2 * CYXPerBlock);
+                b_block_data_begin += 2 * EPerBlock;
+            } while(b_block_data_begin < E - 2 * EPerBlock);
         }
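Reviewer note: the loop above is the standard LDS double-buffering pipeline: while the GEMM consumes the slice loaded in the previous step, the copies for the next EPerBlock slice are issued into the other buffer. A compact host-side sketch of that even/odd schedule, with the blockwise copies and the GEMM reduced to prints and with assumed values for E and EPerBlock:

```cpp
#include <cstdio>

// Illustrative only: the even/odd double-buffer schedule used above.
int main()
{
    const int E = 36, EPerBlock = 4; // assumed; E % EPerBlock == 0
    int e = 0;                       // corresponds to b_block_data_begin

    std::printf("preload -> buf 0 (e = %d)\n", e); // done before the loop in the kernel

    do
    {
        std::printf("load    -> buf 1 (e = %d)\n", e + EPerBlock);     // even iteration
        std::printf("gemm    on buf 0 (e = %d)\n", e);
        std::printf("load    -> buf 0 (e = %d)\n", e + 2 * EPerBlock); // odd iteration
        std::printf("gemm    on buf 1 (e = %d)\n", e + EPerBlock);
        e += 2 * EPerBlock;
    } while(e < E - 2 * EPerBlock);

    // The already-loaded slices in [e, E) are consumed by the tail code.
    std::printf("tail handles e in [%d, %d)\n", e, E);
}
```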
         // LDS double buffer: tail
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc,
+            a_blockwise_copy.MoveSrcSliceWindow(a_e_k_global_desc,
                                                 a_block_slice_copy_step,
                                                 a_k_m_global_move_slice_window_iterator_hack);
-            b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
+            b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
                                                      b_thread_slice_copy_step);
             __syncthreads();
             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+            a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+            b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
                                       p_b_global,
-                                      b_cyx_n_h_w_thread_desc,
+                                      b_e_n_h_w_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       p_b_thread_double + b_thread_space_size,
-                                      b_cyx_n_h_w_global_iterator_hacks);
+                                      b_e_n_h_w_global_iterator_hacks);
             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(p_a_block_double, p_b_thread_double, p_c_thread);
             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double + a_block_space_size);
+            a_blockwise_copy.RunWrite(a_e_k_block_desc, p_a_block_double + a_block_space_size);
             __syncthreads();
@@ -410,9 +410,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
     // pass tensor descriptor by reference
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_cyx_k_global_desc,
+    __device__ void Run(const AGlobalDesc& a_e_k_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc& b_cyx_n_h_w_global_desc,
+                        const BGlobalDesc& b_e_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
                         const CGlobalDesc& c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
@@ -423,9 +423,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         __shared__ Float p_shared_block[shared_block_size];
-        Run(a_cyx_k_global_desc,
+        Run(a_e_k_global_desc,
             p_a_global,
-            b_cyx_n_h_w_global_desc,
+            b_e_n_h_w_global_desc,
             p_b_global,
             c_k_n_h_w_global_desc,
             p_c_global,
@@ -436,22 +436,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
     // pass tensor descriptors by their pointers
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc* p_a_cyx_k_global_desc,
+    __device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc* p_b_cyx_n_h_w_global_desc,
+                        const BGlobalDesc* p_b_e_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
                         const CGlobalDesc* p_c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
     {
-        const auto a_cyx_k_global_desc     = *p_a_cyx_k_global_desc;
-        const auto b_cyx_n_h_w_global_desc = *p_b_cyx_n_h_w_global_desc;
+        const auto a_e_k_global_desc     = *p_a_e_k_global_desc;
+        const auto b_e_n_h_w_global_desc = *p_b_e_n_h_w_global_desc;
         const auto c_k_n_h_w_global_desc = *p_c_k_n_h_w_global_desc;
-        Run(a_cyx_k_global_desc,
+        Run(a_e_k_global_desc,
             p_a_global,
-            b_cyx_n_h_w_global_desc,
+            b_e_n_h_w_global_desc,
             p_b_global,
             c_k_n_h_w_global_desc,
             p_c_global,
@@ -461,25 +461,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
     // pass tensor descriptors by void*
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const void* p_a_cyx_k_global_desc,
+    __device__ void Run(const void* p_a_e_k_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const void* p_b_cyx_n_h_w_global_desc,
+                        const void* p_b_e_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
                         const void* p_c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
     {
-        const auto a_cyx_k_global_desc =
-            *reinterpret_cast<const AGlobalDesc*>(p_a_cyx_k_global_desc);
-        const auto b_cyx_n_h_w_global_desc =
-            *reinterpret_cast<const BGlobalDesc*>(p_b_cyx_n_h_w_global_desc);
+        const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
+        const auto b_e_n_h_w_global_desc =
+            *reinterpret_cast<const BGlobalDesc*>(p_b_e_n_h_w_global_desc);
         const auto c_k_n_h_w_global_desc =
             *reinterpret_cast<const CGlobalDesc*>(p_c_k_n_h_w_global_desc);
-        Run(a_cyx_k_global_desc,
+        Run(a_e_k_global_desc,
             p_a_global,
-            b_cyx_n_h_w_global_desc,
+            b_e_n_h_w_global_desc,
             p_b_global,
             c_k_n_h_w_global_desc,
             p_c_global,
@@ -498,11 +497,11 @@ template <index_t BlockSize,
           index_t KPerBlock,
           index_t HPerBlock,
           index_t WPerBlock,
-          index_t CYXPerBlock,
+          index_t EPerBlock,
           index_t KPerThread,
           index_t HPerThread,
           index_t WPerThread,
-          index_t CYXPerThread,
+          index_t EPerThread,
           typename ABlockTransferThreadSliceLengths_K_M,
           typename ABlockTransferThreadClusterLengths_K_M,
           typename ABlockTransferThreadClusterArrangeOrder,
@@ -529,7 +528,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
 {
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        const auto CYX = 4 * 3 * 3;
+        const auto E = 4 * 3 * 3;
         const auto K = 16;
         constexpr auto max_lds_align =
@@ -537,20 +536,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_cyx_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<CYX>{}, Number<K>{}), max_lds_align);
+        constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<E>{}, Number<K>{}), max_lds_align);
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
-            math::integer_least_multiple(a_cyx_k_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align);
         return a_block_space_size * sizeof(Float);
     }
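Reviewer note: unlike the v2 kernel above, which doubles the A allocation for two LDS buffers, v3 sizes LDS for the whole E x K slice of A at once, using hard-coded E = 4*3*3 and K = 16 here. A sketch of the same size computation, under the assumptions that max_lds_align is a small power-of-two element count and Float is fp32 (the real code derives both from the kernel's configuration):

```cpp
#include <cstdio>

// Illustrative only: LDS byte count for an E x K tile of A, rounded up to an
// alignment expressed in elements, then converted to bytes.
constexpr int integer_least_multiple(int x, int align)
{
    return ((x + align - 1) / align) * align;
}

int main()
{
    constexpr int E = 4 * 3 * 3; // 36, as hard-coded in the hunk above
    constexpr int K = 16;
    constexpr int max_lds_align = 4; // assumed alignment in elements

    constexpr int a_block_space_size = integer_least_multiple(E * K, max_lds_align);

    std::printf("single-buffer LDS (v3 style): %zu bytes, double-buffer (v2 style): %zu bytes\n",
                a_block_space_size * sizeof(float),
                2 * a_block_space_size * sizeof(float));
}
```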
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_cyx_k_global_desc,
+    __device__ void Run(const AGlobalDesc& a_e_k_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc& b_cyx_n_h_w_global_desc,
+                        const BGlobalDesc& b_e_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
                         const CGlobalDesc& c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
@@ -563,12 +562,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};
-        const auto CYX = a_cyx_k_global_desc.GetLength(I0);
-        const auto K = a_cyx_k_global_desc.GetLength(I1);
-        const auto N = b_cyx_n_h_w_global_desc.GetLength(I1);
-        const auto H = b_cyx_n_h_w_global_desc.GetLength(I2);
-        const auto W = b_cyx_n_h_w_global_desc.GetLength(I3);
+        const auto E = a_e_k_global_desc.GetLength(I0);
+        const auto K = a_e_k_global_desc.GetLength(I1);
+        const auto N = b_e_n_h_w_global_desc.GetLength(I1);
+        const auto H = b_e_n_h_w_global_desc.GetLength(I2);
+        const auto W = b_e_n_h_w_global_desc.GetLength(I3);
         // divide block work by [M, N]
 #if 1
@@ -601,17 +600,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_cyx_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<CYXPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
+        constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
-        constexpr auto a_cyx_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<CYX>{}, Number<K>{}), max_lds_align);
+        constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<E>{}, Number<K>{}), max_lds_align);
         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_cyx_n_h_w_block_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                Number<CYXPerBlock>{}, Number<1>{}, Number<HPerBlock>{}, Number<WPerBlock>{}));
+        constexpr auto b_e_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<EPerBlock>{}, Number<1>{}, Number<HPerBlock>{}, Number<WPerBlock>{}));
         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
@@ -619,15 +617,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
                 Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-        const auto blockwise_gemm =
-            BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
-                                            decltype(a_cyx_k_block_desc),
-                                            decltype(b_cyx_n_h_w_block_desc),
+        const auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
+                                                                    decltype(a_e_k_block_desc),
+                                                                    decltype(b_e_n_h_w_block_desc),
                                             decltype(c_k_n_h_w_thread_desc),
                                             KPerThread,   // KPerThreadSubC
                                             HPerThread,   // HPerThreadSubC
                                             WPerThread,   // WPerThreadSubC
-                                            CYXPerThread, // CYXPerThreadLoop
+                                            EPerThread,   // EPerThreadLoop
                                             1,            // ThreadGemmADataPerRead_K
                                             1             // ThreadGemmBDataPerRead_W
                                             >{};
@@ -646,17 +643,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread;
         // A matrix blockwise copy
-        auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4<
-            BlockSize,
-            InMemoryDataOperation::Set,
-            Sequence<CYX, K>,
-            ABlockTransferThreadSliceLengths_K_M,
-            ABlockTransferThreadClusterLengths_K_M,
-            ABlockTransferThreadClusterArrangeOrder,
-            Float,
-            Float,
-            decltype(a_cyx_k_global_desc),
-            decltype(a_cyx_k_desc),
-            ABlockTransferSrcAccessOrder,
-            Sequence<0, 1>,
-            ABlockTransferSrcVectorDim,
+        auto a_blockwise_copy =
+            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
+                                                   InMemoryDataOperation::Set,
+                                                   Sequence<E, K>,
+                                                   ABlockTransferThreadSliceLengths_K_M,
+                                                   ABlockTransferThreadClusterLengths_K_M,
+                                                   ABlockTransferThreadClusterArrangeOrder,
+                                                   Float,
+                                                   Float,
+                                                   decltype(a_e_k_global_desc),
+                                                   decltype(a_e_k_desc),
+                                                   ABlockTransferSrcAccessOrder,
+                                                   Sequence<0, 1>,
+                                                   ABlockTransferSrcVectorDim,
@@ -668,21 +665,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                                    1,
                                                    1,
                                                    AThreadTransferSrcResetCoordinateAfterRun,
-                                                   true>(a_cyx_k_global_desc,
+                                                   true>(
+                a_e_k_global_desc,
                 make_multi_index(0, k_block_data_on_global),
-                a_cyx_k_desc,
+                a_e_k_desc,
                 make_multi_index(0, 0));
-        constexpr auto b_cyx_n_h_w_thread_desc =
+        constexpr auto b_e_n_h_w_thread_desc =
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                Number<CYXPerBlock>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
+                Number<EPerBlock>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
         auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
             Float,
             Float,
-            decltype(b_cyx_n_h_w_global_desc),
-            decltype(b_cyx_n_h_w_thread_desc),
-            Sequence<CYXPerBlock, 1, HPerThread, WPerThread>,
+            decltype(b_e_n_h_w_global_desc),
+            decltype(b_e_n_h_w_thread_desc),
+            Sequence<EPerBlock, 1, HPerThread, WPerThread>,
             Sequence<3, 2, 0, 1>, // BBlockTransferSrcAccessOrder,
             3,                    // BBlockTransferSrcVectorDim,
             1,                    // BBlockTransferSrcScalarPerVector,
@@ -690,7 +688,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             AddressSpace::Vgpr,
             InMemoryDataOperation::Set,
             1,
-            true>(b_cyx_n_h_w_global_desc,
+            true>(b_e_n_h_w_global_desc,
                   make_multi_index(0, 0, h_thread_data_on_global, w_thread_data_on_global));
         Float* p_a_block = p_shared_block;
@@ -701,36 +699,36 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         // zero out threadwise output
         threadwise_matrix_set_zero_v3(c_k_n_h_w_thread_desc, p_c_thread);
-        constexpr auto b_thread_slice_copy_step = make_multi_index(CYXPerBlock, 0, 0, 0);
+        constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
-        constexpr auto b_cyx_n_h_w_global_iterator_hacks = BGlobalIteratorHacks{};
+        constexpr auto b_e_n_h_w_global_iterator_hacks = BGlobalIteratorHacks{};
         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
         constexpr auto a_k_m_global_move_slice_window_iterator_hack =
             AGlobalMoveSliceWindowIteratorHacks{};
-        constexpr auto b_cyx_n_h_w_global_move_slice_window_iterator_hack =
+        constexpr auto b_e_n_h_w_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};
-        constexpr auto b_thread_space_size = b_cyx_n_h_w_thread_desc.GetElementSpaceSize();
+        constexpr auto b_thread_space_size = b_e_n_h_w_thread_desc.GetElementSpaceSize();
         Float p_b_thread[b_thread_space_size * 2];
         Float* p_b_thread_double = p_b_thread;
         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+            a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+            b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
                                       p_b_global,
-                                      b_cyx_n_h_w_thread_desc,
+                                      b_e_n_h_w_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       p_b_thread_double,
-                                      b_cyx_n_h_w_global_iterator_hacks);
-            a_blockwise_copy.RunWrite(a_cyx_k_desc, p_a_block);
+                                      b_e_n_h_w_global_iterator_hacks);
+            a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
         }
         __syncthreads();
@@ -748,69 +746,69 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             do
             {
                 // even iteration
-                b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
+                b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
                                                          b_thread_slice_copy_step);
-                b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+                b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
                                           p_b_global,
-                                          b_cyx_n_h_w_thread_desc,
+                                          b_e_n_h_w_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           p_b_thread_odd,
-                                          b_cyx_n_h_w_global_iterator_hacks);
+                                          b_e_n_h_w_global_iterator_hacks);
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block + a_cyx_k_block_desc.CalculateOffset(
-                                       make_tuple(b_block_data_begin, 0)),
-                                   p_b_thread_even,
-                                   p_c_thread);
+                blockwise_gemm.Run(
+                    p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
+                    p_b_thread_even,
+                    p_c_thread);
-                b_block_data_begin += CYXPerBlock;
+                b_block_data_begin += EPerBlock;
-                b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
+                b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
                                                          b_thread_slice_copy_step);
-                b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+                b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
                                           p_b_global,
-                                          b_cyx_n_h_w_thread_desc,
+                                          b_e_n_h_w_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           p_b_thread_even,
-                                          b_cyx_n_h_w_global_iterator_hacks);
+                                          b_e_n_h_w_global_iterator_hacks);
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block + a_cyx_k_block_desc.CalculateOffset(
-                                       make_tuple(b_block_data_begin, 0)),
-                                   p_b_thread_odd,
-                                   p_c_thread);
+                blockwise_gemm.Run(
+                    p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
+                    p_b_thread_odd,
+                    p_c_thread);
-                b_block_data_begin += CYXPerBlock;
-            } while(b_block_data_begin < CYX - 2 * CYXPerBlock);
+                b_block_data_begin += EPerBlock;
+            } while(b_block_data_begin < E - 2 * EPerBlock);
         }
         // LDS double buffer: tail
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
-            b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
+            b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
                                                      b_thread_slice_copy_step);
-            b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+            b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
                                       p_b_global,
-                                      b_cyx_n_h_w_thread_desc,
+                                      b_e_n_h_w_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       p_b_thread_double + b_thread_space_size,
-                                      b_cyx_n_h_w_global_iterator_hacks);
+                                      b_e_n_h_w_global_iterator_hacks);
             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(
-                p_a_block + a_cyx_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
+                p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
                 p_b_thread_double,
                 p_c_thread);
-            b_block_data_begin += CYXPerBlock;
+            b_block_data_begin += EPerBlock;
             // LDS double buffer: GEMM on last data
             blockwise_gemm.Run(
-                p_a_block + a_cyx_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
+                p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
                 p_b_thread_double + b_thread_space_size,
                 p_c_thread);
         }
@@ -818,7 +816,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         {
             // LDS double buffer: GEMM on last data
             blockwise_gemm.Run(
-                p_a_block + a_cyx_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
+                p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
                 p_b_thread_double,
                 p_c_thread);
         }
@@ -862,9 +860,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
     // pass tensor descriptor by reference
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_cyx_k_global_desc,
+    __device__ void Run(const AGlobalDesc& a_e_k_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc& b_cyx_n_h_w_global_desc,
+                        const BGlobalDesc& b_e_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
                         const CGlobalDesc& c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
@@ -875,9 +873,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         __shared__ Float p_shared_block[shared_block_size];
-        Run(a_cyx_k_global_desc,
+        Run(a_e_k_global_desc,
             p_a_global,
-            b_cyx_n_h_w_global_desc,
+            b_e_n_h_w_global_desc,
             p_b_global,
             c_k_n_h_w_global_desc,
             p_c_global,
@@ -888,22 +886,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
     // pass tensor descriptors by their pointers
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc* p_a_cyx_k_global_desc,
+    __device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc* p_b_cyx_n_h_w_global_desc,
+                        const BGlobalDesc* p_b_e_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
                         const CGlobalDesc* p_c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
     {
-        const auto a_cyx_k_global_desc     = *p_a_cyx_k_global_desc;
-        const auto b_cyx_n_h_w_global_desc = *p_b_cyx_n_h_w_global_desc;
+        const auto a_e_k_global_desc     = *p_a_e_k_global_desc;
+        const auto b_e_n_h_w_global_desc = *p_b_e_n_h_w_global_desc;
         const auto c_k_n_h_w_global_desc = *p_c_k_n_h_w_global_desc;
-        Run(a_cyx_k_global_desc,
+        Run(a_e_k_global_desc,
             p_a_global,
-            b_cyx_n_h_w_global_desc,
+            b_e_n_h_w_global_desc,
             p_b_global,
             c_k_n_h_w_global_desc,
             p_c_global,
@@ -913,25 +911,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
     // pass tensor descriptors by void*
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const void* p_a_cyx_k_global_desc,
+    __device__ void Run(const void* p_a_e_k_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const void* p_b_cyx_n_h_w_global_desc,
+                        const void* p_b_e_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
                         const void* p_c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
     {
-        const auto a_cyx_k_global_desc =
-            *reinterpret_cast<const AGlobalDesc*>(p_a_cyx_k_global_desc);
-        const auto b_cyx_n_h_w_global_desc =
-            *reinterpret_cast<const BGlobalDesc*>(p_b_cyx_n_h_w_global_desc);
+        const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
+        const auto b_e_n_h_w_global_desc =
+            *reinterpret_cast<const BGlobalDesc*>(p_b_e_n_h_w_global_desc);
         const auto c_k_n_h_w_global_desc =
             *reinterpret_cast<const CGlobalDesc*>(p_c_k_n_h_w_global_desc);
-        Run(a_cyx_k_global_desc,
+        Run(a_e_k_global_desc,
             p_a_global,
-            b_cyx_n_h_w_global_desc,
+            b_e_n_h_w_global_desc,
             p_b_global,
             c_k_n_h_w_global_desc,
             p_c_global,
...
@@ -57,10 +57,10 @@ struct ThreadwiseGemm_km_kn_mn_v3
         constexpr auto H = BDesc{}.GetLength(I2);
         constexpr auto W = BDesc{}.GetLength(I3);
-        constexpr auto CYX = ADesc{}.GetLength(I0);
+        constexpr auto E = ADesc{}.GetLength(I0);
         constexpr auto K = ADesc{}.GetLength(I1);
-        static_for<0, CYX, 1>{}([&](auto e) {
+        static_for<0, E, 1>{}([&](auto e) {
             static_for<0, K, 1>{}([&](auto k) {
                 static_for<0, H, 1>{}([&](auto h) {
                     static_for<0, W, 1>{}([&](auto w) {
...
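Reviewer note: the truncated hunk above is the per-thread microkernel of ThreadwiseGemm_km_kn_mn_v3, whose loop nest (A indexed as [E, K], B as [E, N, H, W]) implies an accumulation of the form c[k][h][w] += a[e][k] * b[e][h][w] with E as the reduction axis; the body itself is not shown in this diff, so the exact indexing is an inference. A plain-loop sketch of that contraction with made-up tile sizes:

```cpp
#include <cstdio>

// Illustrative only: the per-thread contraction implied by the loop nest above,
// written as plain nested loops. Tile sizes are assumed example values.
int main()
{
    constexpr int E = 4, K = 4, H = 2, W = 2;

    float a[E][K]    = {}; // A tile: [E, K]
    float b[E][H][W] = {}; // B tile: [E, 1, H, W] with the N = 1 axis dropped
    float c[K][H][W] = {}; // C accumulator: [K, 1, H, W] with N = 1 dropped

    // fill with something recognizable
    for(int e = 0; e < E; ++e)
        for(int k = 0; k < K; ++k)
            a[e][k] = 1.0f;
    for(int e = 0; e < E; ++e)
        for(int h = 0; h < H; ++h)
            for(int w = 0; w < W; ++w)
                b[e][h][w] = 1.0f;

    for(int e = 0; e < E; ++e)
        for(int k = 0; k < K; ++k)
            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                    c[k][h][w] += a[e][k] * b[e][h][w];

    std::printf("c[0][0][0] = %f (expected %d)\n", c[0][0][0], E);
}
```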
@@ -73,15 +73,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     constexpr index_t KPerBlock = 16;
     constexpr index_t HPerBlock = 16;
     constexpr index_t WPerBlock = 16;
-    constexpr index_t CYXPerBlock = 2 * 3 * 3;
+    constexpr index_t CYXPerBlock = 4;
     constexpr index_t KPerThread = 4;
     constexpr index_t HPerThread = 2;
     constexpr index_t WPerThread = 2;
-    constexpr index_t CYXPerThread = 2;
+    constexpr index_t CYXPerThread = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<9, 16>;
+    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<1, 1>;
+    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
     constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
     constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
...
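Reviewer note: this hunk retunes the host wrapper to match the driver's switch from the v3 back to the v2 gridwise GEMM: the per-block and per-thread E slices both become 4, and the A block-transfer layout changes from 4x1 slices on a 9x16 thread cluster (which covers a whole 36 x 16 E x K panel; the v3 kernel copies all of E at once, and E = 36 is assumed from the value hard-coded in GetSharedMemoryNumberOfByte) to 1x1 slices on a 4x16 cluster (covering the 4 x 16 EPerBlock x KPerBlock tile the v2 kernel copies per iteration). A small sketch of that covering arithmetic, under the assumption that slice lengths times cluster lengths must equal the copied tile in each dimension:

```cpp
#include <cstdio>

// Illustrative only: compare the A tile covered by the old and new
// slice/cluster tunings shown above.
int main()
{
    const int KPerBlock = 16;
    const int E_assumed = 4 * 3 * 3; // 36, assumed from GetSharedMemoryNumberOfByte
    const int EPerBlock_new = 4;

    // old: Sequence<4, 1> slices on a Sequence<9, 16> cluster
    const int old_e = 4 * 9, old_k = 1 * 16;
    // new: Sequence<1, 1> slices on a Sequence<4, 16> cluster
    const int new_e = 1 * 4, new_k = 1 * 16;

    std::printf("old copy tile %d x %d (whole E x K = %d x %d), "
                "new copy tile %d x %d (EPerBlock x KPerBlock = %d x %d)\n",
                old_e, old_k, E_assumed, KPerBlock,
                new_e, new_k, EPerBlock_new, KPerBlock);
}
```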