Commit 18707866 authored by Chao Liu's avatar Chao Liu
Browse files

adding thread group

parent ee33b1fa
......@@ -27,7 +27,8 @@ template <typename ALayout,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t ABBlockTransferThreadGroupSize,
index_t BlockGemmThreadGroupSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
......@@ -346,7 +347,8 @@ struct DeviceGemm_Xdl_CShuffle_v2
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
ABBlockTransferThreadGroupSize,
BlockGemmThreadGroupSize,
MPerBlock,
NPerBlock,
KPerBlock,
......@@ -487,7 +489,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
......@@ -502,22 +504,22 @@ struct DeviceGemm_Xdl_CShuffle_v2
}
else
{
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
}
}
else
......@@ -539,7 +541,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
......@@ -554,22 +556,22 @@ struct DeviceGemm_Xdl_CShuffle_v2
}
else
{
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
}
}
......@@ -673,7 +675,8 @@ struct DeviceGemm_Xdl_CShuffle_v2
// clang-format off
str << "DeviceGemm_Xdl_CShuffle_v2"
<< "<"
<< BlockSize << ", "
<< ABBlockTransferThreadGroupSize << ", "
<< BlockGemmThreadGroupSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
......
......@@ -4,7 +4,9 @@
namespace ck {
template <typename AGridDesc,
template <typename ABBlockTransferThreadGroup,
typename BlockGemmThreadGroup,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
......@@ -23,7 +25,9 @@ template <typename AGridDesc,
struct GridwiseGemmPipeline_v2;
// 1-stage prefetch
template <typename AGridDesc,
template <typename ABBlockTransferThreadGroup,
typename BlockGemmThreadGroup,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
......@@ -38,7 +42,9 @@ template <typename AGridDesc,
typename BlockwiseGemm,
typename CThreadBuffer,
bool HasMainLoop>
struct GridwiseGemmPipeline_v2<AGridDesc,
struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
BlockGemmThreadGroup,
AGridDesc,
ABlockDesc,
ABlockTransfer,
AGridBuffer,
......@@ -58,19 +64,24 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static __device__ void RunProducer(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
index_t num_loop)
__device__ constexpr GridwiseGemmPipeline_v2()
{
// TODO static assert
}
static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
index_t num_loop)
{
// global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
......@@ -140,11 +151,11 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
}
}
static __device__ void RunConsumer(ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// Initialize C
c_thread_buf.Clear();
......@@ -193,6 +204,45 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
// Top-level pipeline entry: routes each thread to exactly one of the two
// specialized pipelines based on its thread-group membership.
//   - Threads in ABBlockTransferThreadGroup run the global->LDS A/B copy pipeline.
//   - Threads in BlockGemmThreadGroup run the LDS->VGPR blockwise GEMM pipeline.
// Threads belonging to neither group (if the groups do not cover the whole
// block) fall through and do nothing. The two groups communicate solely
// through the shared a_block_buf / b_block_buf LDS buffers; any required
// synchronization is performed inside the two pipeline routines.
static __device__ void Run(const AGridDesc& a_grid_desc,
                           const ABlockDesc& a_block_desc,
                           ABlockTransfer& a_blockwise_copy,
                           const AGridBuffer& a_grid_buf,
                           ABlockBuffer& a_block_buf,
                           const ABlockTransferStep& a_block_copy_step,
                           const BGridDesc& b_grid_desc,
                           const BBlockDesc& b_block_desc,
                           BBlockTransfer& b_blockwise_copy,
                           const BGridBuffer& b_grid_buf,
                           BBlockBuffer& b_block_buf,
                           const BBlockTransferStep& b_block_copy_step,
                           const BlockwiseGemm& blockwise_gemm,
                           CThreadBuffer& c_thread_buf,
                           index_t num_loop)
{
    if(ABBlockTransferThreadGroup::IsBelong())
    {
        // Producer half: stream A/B tiles from global memory into LDS.
        // Note: call the static sibling directly and forward this function's own
        // parameters (the previous code referenced an out-of-scope
        // gridwise_gemm_pipeline object and caller-side variable names).
        RunABBlockTransferPipeline(a_grid_desc,
                                   a_block_desc,
                                   a_blockwise_copy,
                                   a_grid_buf,
                                   a_block_buf,
                                   a_block_copy_step,
                                   b_grid_desc,
                                   b_block_desc,
                                   b_blockwise_copy,
                                   b_grid_buf,
                                   b_block_buf,
                                   b_block_copy_step,
                                   num_loop);
    }
    else if(BlockGemmThreadGroup::IsBelong())
    {
        // Consumer half: run the blockwise GEMM on the LDS tiles, accumulating
        // into the per-thread C buffer.
        RunBlockGemmPipeline(
            a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_loop);
    }
}
};
} // namespace ck
......@@ -67,7 +67,8 @@ template <typename FloatAB,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t ABBlockTransferThreadGroupSize,
index_t BlockGemmThreadGroupSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
......@@ -114,6 +115,50 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
using ThisThreadBlock =
AnyThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
#if 1
using ABBlockTransferThreadGroup = ThisThreadBlock;
using BlockGemmThreadGroup = ThisThreadBlock;
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
#else
// Thread group handling the A/B global->LDS block transfer pipeline.
// Occupies the leading [0, ABBlockTransferThreadGroupSize) slice of the
// block's flat thread ids.
struct ABBlockTransferThreadGroup
{
    // Size of this group in threads.
    __device__ static constexpr index_t GetNumOfThread() { return ABBlockTransferThreadGroupSize; }

    // True iff the calling thread's block-local 1D id falls in this group's range.
    __device__ static constexpr bool IsBelong()
    {
        return get_thread_local_1d_id() < ABBlockTransferThreadGroupSize;
    }

    // Group-local id; identical to the block-local id since this group leads the block.
    __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
};
// Thread group handling the blockwise GEMM (LDS->VGPR math) pipeline.
// Occupies the trailing slice of the block's flat thread ids, i.e.
// [ABBlockTransferThreadGroupSize, ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize).
struct BlockGemmThreadGroup
{
    // Size of this group in threads.
    // Fix: previously returned ABBlockTransferThreadGroupSize (copy-paste from
    // the transfer group), which misreported this group's size.
    __device__ static constexpr index_t GetNumOfThread()
    {
        return BlockGemmThreadGroupSize;
    }

    // True iff the calling thread lies past the transfer group's range.
    __device__ static constexpr bool IsBelong()
    {
        return get_thread_local_1d_id() >= ABBlockTransferThreadGroupSize;
    }

    // Group-local id: block-local id rebased so this group's first thread is 0.
    __device__ static index_t GetThreadId()
    {
        return get_thread_local_1d_id() - ABBlockTransferThreadGroupSize;
    }
};
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
......@@ -345,11 +390,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
using ThisThreadBlock = AnyThreadBlock<BlockSize>;
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
ThreadGroupTensorSliceTransfer_v4r1<ABBlockTransferThreadGroup,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
......@@ -380,7 +423,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
ThreadGroupTensorSliceTransfer_v4r1<ABBlockTransferThreadGroup,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
......@@ -420,7 +463,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockGemmThreadGroup,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
......@@ -447,6 +490,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
#if 1
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
......@@ -465,10 +513,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
remove_cvref_t<decltype(c_thread_buf)>,
NumGemmKPrefetchStage,
HasMainK0BlockLoop>{};
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
#else
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
BlockGemmThreadGroup,
remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
remove_cvref_t<decltype(a_blockwise_copy)>,
remove_cvref_t<decltype(a_grid_buf)>,
remove_cvref_t<decltype(a_block_buf)>,
remove_cvref_t<decltype(a_block_slice_copy_step)>,
remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
remove_cvref_t<decltype(b_blockwise_copy)>,
remove_cvref_t<decltype(b_grid_buf)>,
remove_cvref_t<decltype(b_block_buf)>,
remove_cvref_t<decltype(b_block_slice_copy_step)>,
remove_cvref_t<decltype(blockwise_gemm)>,
remove_cvref_t<decltype(c_thread_buf)>,
NumGemmKPrefetchStage,
HasMainK0BlockLoop>{};
#endif
gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
......@@ -601,7 +667,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // index_t BlockSize,
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
......@@ -655,22 +721,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
if(BlockGemmThreadGroup::IsBelong())
{
// thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
}
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if(CShuffleBlockTransferThreadGroup::IsBelong())
{
// block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(access_id < num_access - 1)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment