Commit 18707866 authored by Chao Liu
Browse files

adding thread group

parent ee33b1fa
...@@ -27,7 +27,8 @@ template <typename ALayout, ...@@ -27,7 +27,8 @@ template <typename ALayout,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t ABBlockTransferThreadGroupSize,
index_t BlockGemmThreadGroupSize,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
...@@ -346,7 +347,8 @@ struct DeviceGemm_Xdl_CShuffle_v2 ...@@ -346,7 +347,8 @@ struct DeviceGemm_Xdl_CShuffle_v2
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, ABBlockTransferThreadGroupSize,
BlockGemmThreadGroupSize,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
...@@ -487,7 +489,7 @@ struct DeviceGemm_Xdl_CShuffle_v2 ...@@ -487,7 +489,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
{ {
launch_kernel(kernel, launch_kernel(kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
...@@ -502,22 +504,22 @@ struct DeviceGemm_Xdl_CShuffle_v2 ...@@ -502,22 +504,22 @@ struct DeviceGemm_Xdl_CShuffle_v2
} }
else else
{ {
ave_time = ave_time = launch_and_time_kernel(
launch_and_time_kernel(kernel, kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
} }
else else
...@@ -539,7 +541,7 @@ struct DeviceGemm_Xdl_CShuffle_v2 ...@@ -539,7 +541,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
{ {
launch_kernel(kernel, launch_kernel(kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
...@@ -554,22 +556,22 @@ struct DeviceGemm_Xdl_CShuffle_v2 ...@@ -554,22 +556,22 @@ struct DeviceGemm_Xdl_CShuffle_v2
} }
else else
{ {
ave_time = ave_time = launch_and_time_kernel(
launch_and_time_kernel(kernel, kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
} }
...@@ -673,7 +675,8 @@ struct DeviceGemm_Xdl_CShuffle_v2 ...@@ -673,7 +675,8 @@ struct DeviceGemm_Xdl_CShuffle_v2
// clang-format off // clang-format off
str << "DeviceGemm_Xdl_CShuffle_v2" str << "DeviceGemm_Xdl_CShuffle_v2"
<< "<" << "<"
<< BlockSize << ", " << ABBlockTransferThreadGroupSize << ", "
<< BlockGemmThreadGroupSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
......
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
namespace ck { namespace ck {
template <typename AGridDesc, template <typename ABBlockTransferThreadGroup,
typename BlockGemmThreadGroup,
typename AGridDesc,
typename ABlockDesc, typename ABlockDesc,
typename ABlockTransfer, typename ABlockTransfer,
typename AGridBuffer, typename AGridBuffer,
...@@ -23,7 +25,9 @@ template <typename AGridDesc, ...@@ -23,7 +25,9 @@ template <typename AGridDesc,
struct GridwiseGemmPipeline_v2; struct GridwiseGemmPipeline_v2;
// 1-stage prefetch // 1-stage prefetch
template <typename AGridDesc, template <typename ABBlockTransferThreadGroup,
typename BlockGemmThreadGroup,
typename AGridDesc,
typename ABlockDesc, typename ABlockDesc,
typename ABlockTransfer, typename ABlockTransfer,
typename AGridBuffer, typename AGridBuffer,
...@@ -38,7 +42,9 @@ template <typename AGridDesc, ...@@ -38,7 +42,9 @@ template <typename AGridDesc,
typename BlockwiseGemm, typename BlockwiseGemm,
typename CThreadBuffer, typename CThreadBuffer,
bool HasMainLoop> bool HasMainLoop>
struct GridwiseGemmPipeline_v2<AGridDesc, struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
BlockGemmThreadGroup,
AGridDesc,
ABlockDesc, ABlockDesc,
ABlockTransfer, ABlockTransfer,
AGridBuffer, AGridBuffer,
...@@ -58,19 +64,24 @@ struct GridwiseGemmPipeline_v2<AGridDesc, ...@@ -58,19 +64,24 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static __device__ void RunProducer(const AGridDesc& a_grid_desc, __device__ constexpr GridwiseGemmPipeline_v2()
const ABlockDesc& a_block_desc, {
ABlockTransfer& a_blockwise_copy, // TODO static assert
const AGridBuffer& a_grid_buf, }
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step, static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc, const ABlockDesc& a_block_desc,
const BBlockDesc& b_block_desc, ABlockTransfer& a_blockwise_copy,
BBlockTransfer& b_blockwise_copy, const AGridBuffer& a_grid_buf,
const BGridBuffer& b_grid_buf, ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf, const ABlockTransferStep& a_block_copy_step,
const BBlockTransferStep& b_block_copy_step, const BGridDesc& b_grid_desc,
index_t num_loop) const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
index_t num_loop)
{ {
// global read 0 // global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
...@@ -140,11 +151,11 @@ struct GridwiseGemmPipeline_v2<AGridDesc, ...@@ -140,11 +151,11 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
} }
} }
static __device__ void RunConsumer(ABlockBuffer& a_block_buf, static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf, BBlockBuffer& b_block_buf,
const BlockwiseGemm& blockwise_gemm, const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop)
{ {
// Initialize C // Initialize C
c_thread_buf.Clear(); c_thread_buf.Clear();
...@@ -193,6 +204,45 @@ struct GridwiseGemmPipeline_v2<AGridDesc, ...@@ -193,6 +204,45 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
} }
} }
static __device__ void Run(const AGridDesc& a_grid_desc,
                           const ABlockDesc& a_block_desc,
                           ABlockTransfer& a_blockwise_copy,
                           const AGridBuffer& a_grid_buf,
                           ABlockBuffer& a_block_buf,
                           const ABlockTransferStep& a_block_copy_step,
                           const BGridDesc& b_grid_desc,
                           const BBlockDesc& b_block_desc,
                           BBlockTransfer& b_blockwise_copy,
                           const BGridBuffer& b_grid_buf,
                           BBlockBuffer& b_block_buf,
                           const BBlockTransferStep& b_block_copy_step,
                           const BlockwiseGemm& blockwise_gemm,
                           CThreadBuffer& c_thread_buf,
                           index_t num_loop)
{
    // Dispatch each thread to one of the two cooperating pipelines:
    // threads of ABBlockTransferThreadGroup stream the A/B tiles from
    // global memory into the block buffers, while threads of
    // BlockGemmThreadGroup consume those buffers and accumulate the GEMM
    // result into c_thread_buf. The two groups are assumed disjoint.
    //
    // Fix vs. original: call the struct's own static pipeline stages
    // directly, forwarding this function's parameters. The original went
    // through a non-existent `gridwise_gemm_pipeline` object and passed
    // identifiers (a_grid_desc_ak0_m_ak1, a_block_slice_copy_step, ...)
    // that are not declared in this scope.
    if(ABBlockTransferThreadGroup::IsBelong())
    {
        RunABBlockTransferPipeline(a_grid_desc,
                                   a_block_desc,
                                   a_blockwise_copy,
                                   a_grid_buf,
                                   a_block_buf,
                                   a_block_copy_step,
                                   b_grid_desc,
                                   b_block_desc,
                                   b_blockwise_copy,
                                   b_grid_buf,
                                   b_block_buf,
                                   b_block_copy_step,
                                   num_loop);
    }
    else if(BlockGemmThreadGroup::IsBelong())
    {
        RunBlockGemmPipeline(a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_loop);
    }
}
}; };
} // namespace ck } // namespace ck
...@@ -67,7 +67,8 @@ template <typename FloatAB, ...@@ -67,7 +67,8 @@ template <typename FloatAB,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t ABBlockTransferThreadGroupSize,
index_t BlockGemmThreadGroupSize,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
...@@ -114,6 +115,50 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 ...@@ -114,6 +115,50 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
static constexpr auto AK1 = Number<AK1Value>{}; static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{}; static constexpr auto BK1 = Number<BK1Value>{};
using ThisThreadBlock =
AnyThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
#if 1
using ABBlockTransferThreadGroup = ThisThreadBlock;
using BlockGemmThreadGroup = ThisThreadBlock;
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
#else
// Thread group covering the first ABBlockTransferThreadGroupSize threads of
// the block; these threads run the A/B global->LDS block-transfer pipeline.
// Thread ids are taken from get_thread_local_1d_id(), so this group occupies
// the low contiguous range [0, ABBlockTransferThreadGroupSize).
struct ABBlockTransferThreadGroup
{
// Number of threads that belong to this group.
__device__ static constexpr index_t GetNumOfThread()
{
return ABBlockTransferThreadGroupSize;
}
// True when the calling thread's block-local 1d id falls in this group's range.
__device__ static constexpr bool IsBelong()
{
return get_thread_local_1d_id() < ABBlockTransferThreadGroupSize;
}
// Group-local thread id; identical to the block-local id since this group
// starts at thread 0.
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
};
// Thread group covering the BlockGemmThreadGroupSize threads that follow the
// A/B block-transfer group; these threads run the block-GEMM pipeline.
// Thread ids come from get_thread_local_1d_id(), so this group occupies the
// contiguous range [ABBlockTransferThreadGroupSize,
//                   ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize).
struct BlockGemmThreadGroup
{
    // Number of threads that belong to this group.
    // Fix vs. original: returned ABBlockTransferThreadGroupSize (copy-paste
    // error); IsBelong()/GetThreadId() below already define this group as the
    // BlockGemmThreadGroupSize threads AFTER the transfer group.
    __device__ static constexpr index_t GetNumOfThread()
    {
        return BlockGemmThreadGroupSize;
    }

    // True when the calling thread's block-local 1d id falls past the
    // transfer group, i.e. inside this group's range.
    __device__ static constexpr bool IsBelong()
    {
        return get_thread_local_1d_id() >= ABBlockTransferThreadGroupSize;
    }

    // Group-local thread id (0-based within this group).
    __device__ static index_t GetThreadId()
    {
        return get_thread_local_1d_id() - ABBlockTransferThreadGroupSize;
    }
};
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
...@@ -345,11 +390,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 ...@@ -345,11 +390,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
using ThisThreadBlock = AnyThreadBlock<BlockSize>;
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1<ABBlockTransferThreadGroup,
AElementwiseOperation, AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
...@@ -380,7 +423,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 ...@@ -380,7 +423,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1<ABBlockTransferThreadGroup,
BElementwiseOperation, BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
...@@ -420,7 +463,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 ...@@ -420,7 +463,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockGemmThreadGroup,
FloatAB, FloatAB,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
...@@ -447,6 +490,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 ...@@ -447,6 +490,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
#if 1
// gridwise GEMM pipeline // gridwise GEMM pipeline
const auto gridwise_gemm_pipeline = const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>, GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
...@@ -465,10 +513,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 ...@@ -465,10 +513,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
remove_cvref_t<decltype(c_thread_buf)>, remove_cvref_t<decltype(c_thread_buf)>,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
HasMainK0BlockLoop>{}; HasMainK0BlockLoop>{};
#else
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( // gridwise GEMM pipeline
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / const auto gridwise_gemm_pipeline =
KPerBlock); GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
BlockGemmThreadGroup,
remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
remove_cvref_t<decltype(a_blockwise_copy)>,
remove_cvref_t<decltype(a_grid_buf)>,
remove_cvref_t<decltype(a_block_buf)>,
remove_cvref_t<decltype(a_block_slice_copy_step)>,
remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
remove_cvref_t<decltype(b_blockwise_copy)>,
remove_cvref_t<decltype(b_grid_buf)>,
remove_cvref_t<decltype(b_block_buf)>,
remove_cvref_t<decltype(b_block_slice_copy_step)>,
remove_cvref_t<decltype(blockwise_gemm)>,
remove_cvref_t<decltype(c_thread_buf)>,
NumGemmKPrefetchStage,
HasMainK0BlockLoop>{};
#endif
gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1, gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
...@@ -601,7 +667,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 ...@@ -601,7 +667,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
// shuffle: blockwise copy C from LDS to global // shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // index_t BlockSize, ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation, CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp, CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1, Sequence<1,
...@@ -655,22 +721,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 ...@@ -655,22 +721,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
// make sure it's safe to write to LDS // make sure it's safe to write to LDS
block_sync_lds(); block_sync_lds();
// each thread write its data from VGPR to LDS if(BlockGemmThreadGroup::IsBelong())
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, {
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // thread write its data from VGPR to LDS
c_thread_buf, c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_shuffle_block_buf); c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
}
// make sure it's safe to read from LDS // make sure it's safe to read from LDS
block_sync_lds(); block_sync_lds();
// each block copy its data from LDS to global if(CShuffleBlockTransferThreadGroup::IsBelong())
c_shuffle_block_copy_lds_to_global.Run( {
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, // block copy its data from LDS to global
c_shuffle_block_buf, c_shuffle_block_copy_lds_to_global.Run(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf); c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(access_id < num_access - 1) if constexpr(access_id < num_access - 1)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment