Commit 5b242405 authored by Chao Liu

refactor

parent f1403dac
@@ -17,8 +17,8 @@ template <index_t BlockSize,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
@@ -178,8 +178,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
WoPerBlock,
EPerBlock,
KPerThread,
HPerThread,
WPerThread,
HoPerThread,
WoPerThread,
EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
@@ -3,10 +3,10 @@
template <typename GridwiseOp, typename... Xs>
__global__ void
#if 0
__launch_bounds__(256, 2)
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
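// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) declares
// the maximum workgroup size and asks the compiler to cap register usage so
// at least that many workgroups stay resident per compute unit; with the
// defaults from the config header this is __launch_bounds__(256, 1)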
run_gridwise_operation(Xs... xs)
run_gridwise_operation(Xs... xs)
{
GridwiseOp{}.Run(xs...);
}
@@ -19,12 +19,12 @@ template <index_t BlockSize,
typename BGlobalDesc,
typename CGlobalDesc,
index_t KPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
@@ -69,9 +69,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_h_w_global_desc,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_h_w_global_desc,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
@@ -85,35 +85,35 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
const auto E = a_e_k_global_desc.GetLength(I0);
const auto K = a_e_k_global_desc.GetLength(I1);
const auto N = b_e_n_h_w_global_desc.GetLength(I1);
const auto H = b_e_n_h_w_global_desc.GetLength(I2);
const auto W = b_e_n_h_w_global_desc.GetLength(I3);
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
// divide block work by [K, Ho, Wo]
#if 0
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto h_block_work_num = H / Number<HPerBlock>{};
const auto w_block_work_num = W / Number<WPerBlock>{};
const auto hw_block_work_num = h_block_work_num * w_block_work_num;
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hw_block_work_num;
const index_t hw_block_work_id = get_block_1d_id() - k_block_work_id * hw_block_work_num;
const index_t h_block_work_id = hw_block_work_id / w_block_work_num;
const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num;
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num;
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#else
// Hack: this forces the result into an SGPR
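// (__builtin_amdgcn_readfirstlane broadcasts lane 0's value across the
// wavefront, so the result is provably uniform and the compiler can keep it
// in a scalar register rather than one VGPR per lane)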
const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
const index_t h_block_work_num = __builtin_amdgcn_readfirstlane(H / HPerBlock);
const index_t w_block_work_num = __builtin_amdgcn_readfirstlane(W / WPerBlock);
const index_t hw_block_work_num = h_block_work_num * w_block_work_num;
const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hw_block_work_num);
const index_t hw_block_work_id = get_block_1d_id() - k_block_work_id * hw_block_work_num;
const index_t h_block_work_id =
__builtin_amdgcn_readfirstlane(hw_block_work_id / w_block_work_num);
const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num);
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id =
__builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num);
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#endif
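// illustrative numbers for the decomposition above: with K = 64,
// KPerBlock = 16, Ho = Wo = 32 and HoPerBlock = WoPerBlock = 16,
// hwo_block_work_num = 2 * 2 = 4, so block 13 gets k_block_work_id = 3,
// hwo_block_work_id = 1, ho_block_work_id = 0, wo_block_work_id = 1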
// lds max alignment
@@ -127,39 +127,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_e_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<EPerBlock>{}, Number<1>{}, Number<HPerBlock>{}, Number<WPerBlock>{}));
constexpr auto b_e_n_ho_wo_block_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k_n_h_w_thread_desc =
constexpr auto c_k_n_ho_wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
const auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
decltype(a_e_k_block_desc),
decltype(b_e_n_h_w_block_desc),
decltype(c_k_n_h_w_thread_desc),
KPerThread, // KPerThreadSubC
HPerThread, // HPerThreadSubC
WPerThread, // WPerThreadSubC
EPerThread, // EPerThreadLoop
1, // ThreadGemmADataPerRead_K
1 // ThreadGemmBDataPerRead_W
>{};
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread, // KPerThreadSubC
HoPerThread, // HoPerThreadSubC
WoPerThread, // WoPerThreadSubC
EPerThread, // EPerThreadLoop
1, // ThreadGemmADataPerRead_K
1 // ThreadGemmBDataPerRead_W
>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const auto k_thread_id = c_thread_mtx_index.k;
const auto h_thread_id = c_thread_mtx_index.h;
const auto w_thread_id = c_thread_mtx_index.w;
const auto k_thread_id = c_thread_mtx_index.k;
const auto ho_thread_id = c_thread_mtx_index.h;
const auto wo_thread_id = c_thread_mtx_index.w;
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
const index_t h_block_data_on_global = h_block_work_id * HPerBlock;
const index_t w_block_data_on_global = w_block_work_id * WPerBlock;
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
const index_t h_thread_data_on_global = h_block_data_on_global + h_thread_id * HPerThread;
const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread;
const index_t ho_thread_data_on_global =
ho_block_data_on_global + ho_thread_id * HoPerThread;
const index_t wo_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread;
// A matrix blockwise copy
auto a_blockwise_copy =
@@ -190,26 +194,25 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_e_k_block_desc,
make_multi_index(0, 0));
constexpr auto b_e_n_h_w_thread_desc =
constexpr auto b_e_n_ho_wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
auto b_threadwise_transfer =
ThreadwiseDynamicTensorSliceTransfer_v2<Float,
Float,
decltype(b_e_n_h_w_global_desc),
decltype(b_e_n_h_w_thread_desc),
Sequence<EPerBlock, 1, HPerThread, WPerThread>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
AddressSpace::Global,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>(
b_e_n_h_w_global_desc,
make_multi_index(0, 0, h_thread_data_on_global, w_thread_data_on_global));
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
Float,
Float,
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
AddressSpace::Global,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>(b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
@@ -218,26 +221,26 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
Float* p_a_block_double = p_shared_block;
// register allocation for output
AccFloat p_c_thread[c_k_n_h_w_thread_desc.GetElementSpaceSize()];
AccFloat p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero_v3(c_k_n_h_w_thread_desc, p_c_thread);
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(EPerBlock, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
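// each main-loop iteration advances the A block window and the B thread
// window by EPerBlock along the reduction dimension E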
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_e_n_h_w_global_iterator_hacks = BGlobalIteratorHacks{};
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when moving the slice window over A and
// B matrices for threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_e_n_h_w_global_move_slice_window_iterator_hack =
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_thread_space_size = b_e_n_h_w_thread_desc.GetElementSpaceSize();
constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize();
Float p_b_thread[b_thread_space_size * 2];
Float* p_b_thread_double = p_b_thread;
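// double buffering: the two halves of p_b_thread alternate between being
// filled from global memory and being consumed by the blockwise GEMM, so
// the loads of one iteration overlap the math of the previous one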
@@ -246,12 +249,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
{
a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_h_w_thread_desc,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_double,
b_e_n_h_w_global_iterator_hacks);
b_e_n_ho_wo_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_e_k_block_desc, p_a_block_double);
}
@@ -276,7 +279,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
__syncthreads();
@@ -285,12 +288,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_blockwise_copy.RunRead(
a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_h_w_thread_desc,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_odd,
b_e_n_h_w_global_iterator_hacks);
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_thread_even, p_c_thread);
@@ -303,7 +306,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
__syncthreads();
@@ -311,12 +314,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_blockwise_copy.RunRead(
a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_h_w_thread_desc,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_even,
b_e_n_h_w_global_iterator_hacks);
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_thread_odd, p_c_thread);
@@ -335,7 +338,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
__syncthreads();
@@ -343,12 +346,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_h_w_thread_desc,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_double + b_thread_space_size,
b_e_n_h_w_global_iterator_hacks);
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_thread_double, p_c_thread);
@@ -375,8 +378,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
#if 1
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_h_w_global tensor
constexpr auto c_k_n_h_w_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
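// each thread writes its KPerThread x 1 x HoPerThread x WoPerThread
// register tile directly to the global C tensor at the coordinates
// computed above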
@@ -384,9 +387,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
ThreadwiseDynamicTensorSliceTransfer_v1r3<
AccFloat,
Float,
decltype(c_k_n_h_w_thread_desc),
decltype(c_k_n_h_w_global_desc),
Sequence<KPerThread, 1, HPerThread, WPerThread>,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
@@ -395,15 +398,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
CGlobalMemoryDataOperation,
1,
true>(
c_k_n_h_w_global_desc,
c_k_n_ho_wo_global_desc,
make_multi_index(
k_thread_data_on_global, 0, h_thread_data_on_global, w_thread_data_on_global))
.Run(c_k_n_h_w_thread_desc,
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
.Run(c_k_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_c_thread,
c_k_n_h_w_global_desc,
c_k_n_ho_wo_global_desc,
p_c_global,
c_k_n_h_w_global_tensor_iterator_hacks);
c_k_n_ho_wo_global_tensor_iterator_hacks);
}
#endif
}
@@ -412,9 +415,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_h_w_global_desc,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_h_w_global_desc,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
@@ -425,9 +428,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
Run(a_e_k_global_desc,
p_a_global,
b_e_n_h_w_global_desc,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_h_w_global_desc,
c_k_n_ho_wo_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
@@ -438,22 +441,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_h_w_global_desc,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const CGlobalDesc* p_c_k_n_h_w_global_desc,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *p_a_e_k_global_desc;
const auto b_e_n_h_w_global_desc = *p_b_e_n_h_w_global_desc;
const auto c_k_n_h_w_global_desc = *p_c_k_n_h_w_global_desc;
const auto a_e_k_global_desc = *p_a_e_k_global_desc;
const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc;
const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc;
Run(a_e_k_global_desc,
p_a_global,
b_e_n_h_w_global_desc,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_h_w_global_desc,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -463,24 +466,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const void* p_b_e_n_h_w_global_desc,
const void* p_b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const void* p_c_k_n_h_w_global_desc,
const void* p_c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
const auto b_e_n_h_w_global_desc =
*reinterpret_cast<const BGlobalDesc*>(p_b_e_n_h_w_global_desc);
const auto c_k_n_h_w_global_desc =
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_h_w_global_desc);
const auto b_e_n_ho_wo_global_desc =
*reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
const auto c_k_n_ho_wo_global_desc =
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);
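// the void* overload lets descriptors travel through an untyped kernel
// argument buffer; the reinterpret_casts recover the concrete descriptor
// types before dispatching to the typed Run above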
Run(a_e_k_global_desc,
p_a_global,
b_e_n_h_w_global_desc,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_h_w_global_desc,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -495,12 +498,12 @@ template <index_t BlockSize,
typename BGlobalDesc,
typename CGlobalDesc,
index_t KPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
@@ -548,9 +551,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_h_w_global_desc,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_h_w_global_desc,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
@@ -564,34 +567,34 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const auto E = a_e_k_global_desc.GetLength(I0);
const auto K = a_e_k_global_desc.GetLength(I1);
const auto N = b_e_n_h_w_global_desc.GetLength(I1);
const auto H = b_e_n_h_w_global_desc.GetLength(I2);
const auto W = b_e_n_h_w_global_desc.GetLength(I3);
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
// divide block work by [K, Ho, Wo]
#if 1
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto h_block_work_num = H / Number<HPerBlock>{};
const auto w_block_work_num = W / Number<WPerBlock>{};
const auto hw_block_work_num = h_block_work_num * w_block_work_num;
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hw_block_work_num;
const index_t hw_block_work_id = get_block_1d_id() - k_block_work_id * hw_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num;
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
#else
// Hack: this forces the result into an SGPR
const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
const index_t h_block_work_num = __builtin_amdgcn_readfirstlane(H / HPerBlock);
const index_t w_block_work_num = __builtin_amdgcn_readfirstlane(W / WPerBlock);
const index_t hw_block_work_num = h_block_work_num * w_block_work_num;
const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hw_block_work_num);
const index_t hw_block_work_id = get_block_1d_id() - k_block_work_id * hw_block_work_num;
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num);
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
#endif
const index_t h_block_work_id = hw_block_work_id / w_block_work_num;
const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;
const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num;
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
// lds max alignment
constexpr auto max_lds_align =
@@ -607,39 +610,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_e_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<EPerBlock>{}, Number<1>{}, Number<HPerBlock>{}, Number<WPerBlock>{}));
constexpr auto b_e_n_ho_wo_block_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k_n_h_w_thread_desc =
constexpr auto c_k_n_ho_wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
const auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
decltype(a_e_k_block_desc),
decltype(b_e_n_h_w_block_desc),
decltype(c_k_n_h_w_thread_desc),
KPerThread, // KPerThreadSubC
HPerThread, // HPerThreadSubC
WPerThread, // WPerThreadSubC
EPerThread, // EPerThreadLoop
1, // ThreadGemmADataPerRead_K
1 // ThreadGemmBDataPerRead_W
>{};
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread, // KPerThreadSubC
HoPerThread, // HoPerThreadSubC
WoPerThread, // WoPerThreadSubC
EPerThread, // EPerThreadLoop
1, // ThreadGemmADataPerRead_K
1 // ThreadGemmBDataPerRead_W
>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const auto k_thread_id = c_thread_mtx_index.k;
const auto h_thread_id = c_thread_mtx_index.h;
const auto w_thread_id = c_thread_mtx_index.w;
const auto k_thread_id = c_thread_mtx_index.k;
const auto ho_thread_id = c_thread_mtx_index.h;
const auto wo_thread_id = c_thread_mtx_index.w;
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
const index_t h_block_data_on_global = h_block_work_id * HPerBlock;
const index_t w_block_data_on_global = w_block_work_id * WPerBlock;
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
const index_t h_thread_data_on_global = h_block_data_on_global + h_thread_id * HPerThread;
const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread;
const index_t ho_thread_data_on_global =
ho_block_data_on_global + ho_thread_id * HoPerThread;
const index_t wo_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread;
// A matrix blockwise copy
auto a_blockwise_copy =
@@ -670,16 +677,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
a_e_k_desc,
make_multi_index(0, 0));
constexpr auto b_e_n_h_w_thread_desc =
constexpr auto b_e_n_ho_wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
Float,
Float,
decltype(b_e_n_h_w_global_desc),
decltype(b_e_n_h_w_thread_desc),
Sequence<EPerBlock, 1, HPerThread, WPerThread>,
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
Sequence<3, 2, 0, 1>, // BBlockTransferSrcAccessOrder,
3, // BBlockTransferSrcVectorDim,
1, // BBlockTransferSrcScalarPerVector,
@@ -687,31 +694,31 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>(b_e_n_h_w_global_desc,
make_multi_index(0, 0, h_thread_data_on_global, w_thread_data_on_global));
true>(b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
Float* p_a_block = p_shared_block;
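// unlike the v2 kernel above, A is staged into a single LDS buffer once up
// front; only the per-thread B slice is double buffered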
// register allocation for output
AccFloat p_c_thread[c_k_n_h_w_thread_desc.GetElementSpaceSize()];
AccFloat p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero_v3(c_k_n_h_w_thread_desc, p_c_thread);
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_e_n_h_w_global_iterator_hacks = BGlobalIteratorHacks{};
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when moving the slice window over A and
// B matrices for threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_e_n_h_w_global_move_slice_window_iterator_hack =
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_thread_space_size = b_e_n_h_w_thread_desc.GetElementSpaceSize();
constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize();
Float p_b_thread[b_thread_space_size * 2];
Float* p_b_thread_double = p_b_thread;
@@ -720,12 +727,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
{
a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_h_w_thread_desc,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_double,
b_e_n_h_w_global_iterator_hacks);
b_e_n_ho_wo_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
}
@@ -745,15 +752,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_h_w_thread_desc,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_odd,
b_e_n_h_w_global_iterator_hacks);
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
@@ -763,15 +770,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
b_block_data_begin += EPerBlock;
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_h_w_thread_desc,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_even,
b_e_n_h_w_global_iterator_hacks);
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
@@ -787,15 +794,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_h_w_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_h_w_global_desc,
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_h_w_thread_desc,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_double + b_thread_space_size,
b_e_n_h_w_global_iterator_hacks);
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
@@ -824,8 +831,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
#if 1
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_h_w_global tensor
constexpr auto c_k_n_h_w_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
@@ -833,9 +840,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
ThreadwiseDynamicTensorSliceTransfer_v1r3<
AccFloat,
Float,
decltype(c_k_n_h_w_thread_desc),
decltype(c_k_n_h_w_global_desc),
Sequence<KPerThread, 1, HPerThread, WPerThread>,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
Sequence<3, 2, 0, 1>, // CThreadTransferSrcDstAccessOrder
3, // CThreadTransferSrcDstVectorDim
CThreadTransferDstScalarPerVector,
@@ -844,15 +851,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
CGlobalMemoryDataOperation,
1,
true>(
c_k_n_h_w_global_desc,
c_k_n_ho_wo_global_desc,
make_multi_index(
k_thread_data_on_global, 0, h_thread_data_on_global, w_thread_data_on_global))
.Run(c_k_n_h_w_thread_desc,
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
.Run(c_k_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_c_thread,
c_k_n_h_w_global_desc,
c_k_n_ho_wo_global_desc,
p_c_global,
c_k_n_h_w_global_tensor_iterator_hacks);
c_k_n_ho_wo_global_tensor_iterator_hacks);
}
#endif
}
@@ -861,9 +868,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_h_w_global_desc,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_h_w_global_desc,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
@@ -874,9 +881,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
Run(a_e_k_global_desc,
p_a_global,
b_e_n_h_w_global_desc,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_h_w_global_desc,
c_k_n_ho_wo_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
@@ -887,22 +894,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_h_w_global_desc,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const CGlobalDesc* p_c_k_n_h_w_global_desc,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *p_a_e_k_global_desc;
const auto b_e_n_h_w_global_desc = *p_b_e_n_h_w_global_desc;
const auto c_k_n_h_w_global_desc = *p_c_k_n_h_w_global_desc;
const auto a_e_k_global_desc = *p_a_e_k_global_desc;
const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc;
const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc;
Run(a_e_k_global_desc,
p_a_global,
b_e_n_h_w_global_desc,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_h_w_global_desc,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -912,24 +919,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const void* p_b_e_n_h_w_global_desc,
const void* p_b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const void* p_c_k_n_h_w_global_desc,
const void* p_c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
const auto b_e_n_h_w_global_desc =
*reinterpret_cast<const BGlobalDesc*>(p_b_e_n_h_w_global_desc);
const auto c_k_n_h_w_global_desc =
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_h_w_global_desc);
const auto b_e_n_ho_wo_global_desc =
*reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
const auto c_k_n_ho_wo_global_desc =
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);
Run(a_e_k_global_desc,
p_a_global,
b_e_n_h_w_global_desc,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_h_w_global_desc,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -7,6 +7,10 @@
#endif
#include "bfloat16_dev.hpp"
// device backend
#define CK_DEVICE_BACKEND_AMD 1
// GPU ID
#if 1
#define CK_AMD_GPU_GFX906 1
#elif 0
@@ -15,22 +19,29 @@
#define CK_AMD_GPU_GFX1030 1
#endif
// HIP version
#ifndef CK_HIP_VERSION_FLAT
#define CK_HIP_VERSION_FLAT 0
#endif
// launch bounds
#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 1
#endif
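// note that the kernel wrapper tests the value (#if CK_USE_LAUNCH_BOUNDS)
// while the block above tests definition (#ifdef), so defining the flag as 0
// still defines the two limits but compiles kernels without the attribute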
// buffer resource
#if defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(CK_AMD_GPU_GFX1030)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#endif
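// these constants are dword 3 of the 128-bit buffer resource descriptor
// used by buffer_load/buffer_store; the field layout of that dword differs
// between gfx9 and gfx10, hence the per-GPU values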
#ifndef CK_HIP_VERSION_FLAT
#define CK_HIP_VERSION_FLAT 0
#endif
// multi index
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
// device backend
#define CK_DEVICE_BACKEND_AMD 1
// AMD inline asm
#ifndef CK_USE_AMD_INLINE_ASM
#define CK_USE_AMD_INLINE_ASM 1
@@ -133,6 +133,39 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 1
// cdata = 64, BlockSize 64, 16x256x2
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
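// consistency check for this config: the thread clusters give
// 2 * 2 * 1 * 16 = 64 threads = BlockSize, and each thread accumulates
// (GemmMPerBlock * GemmNPerBlock) / BlockSize = (16 * 256) / 64 = 64
// C values, matching "cdata = 64"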
#elif 1
// cdata = 64, BlockSize 64, 16x256x4
@@ -70,15 +70,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
// cdata = 16, BlockSize = 256, 16x16x16x4
constexpr index_t BlockSize = 256;
constexpr index_t KPerBlock = 16;
constexpr index_t HPerBlock = 16;
constexpr index_t WPerBlock = 16;
constexpr index_t CYXPerBlock = 4;
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 16;
constexpr index_t WoPerBlock = 16;
constexpr index_t EPerBlock = 4;
constexpr index_t KPerThread = 4;
constexpr index_t HPerThread = 2;
constexpr index_t WPerThread = 2;
constexpr index_t CYXPerThread = 4;
constexpr index_t KPerThread = 4;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 4;
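// per-thread output volume: KPerThread * HoPerThread * WoPerThread
// = 4 * 2 * 2 = 16 values ("cdata = 16"); equivalently the block tile
// KPerBlock * HoPerBlock * WoPerBlock = 16 * 16 * 16 over 256 threads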
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
@@ -97,13 +97,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
TDevice,
TDevice,
KPerBlock,
HPerBlock,
WPerBlock,
CYXPerBlock,
HoPerBlock,
WoPerBlock,
EPerBlock,
KPerThread,
HPerThread,
WPerThread,
CYXPerThread,
HoPerThread,
WoPerThread,
EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
GemmABlockTransferSrcScalarPerVector_GemmK,
@@ -34,8 +34,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
@@ -736,7 +736,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 1
#elif 0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size,
acc_data_t,