Commit b66c4cc7 authored by Adam Osewski's avatar Adam Osewski
Browse files

Use GridwiseGemmPipeline

parent 3948d09b
...@@ -73,6 +73,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -73,6 +73,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
// TODO: should be exposed as Tparams.
static constexpr index_t NumGemmKPrefetchStage = 1;
static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
static constexpr PipelineVersion PipelineVer = PipelineVersion::v2;
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
...@@ -85,6 +90,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -85,6 +90,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
GemmSpec, GemmSpec,
NumGemmKPrefetchStage,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, K0PerBlock,
...@@ -112,7 +118,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -112,7 +118,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CShuffleMRepeatPerShuffle, CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle, CShuffleNRepeatPerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferScalarPerVector_NWaveNPerXDL,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopSched,
PipelineVer>;
using Argument = typename GridwiseGemm::Argument; using Argument = typename GridwiseGemm::Argument;
using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
......
...@@ -85,7 +85,7 @@ template <typename ALayout, ...@@ -85,7 +85,7 @@ template <typename ALayout,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
ck::index_t NumPrefetch, ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -152,6 +152,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -152,6 +152,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
GemmSpec, GemmSpec,
NumGemmKPrefetchStage,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, K0PerBlock,
...@@ -179,7 +180,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -179,7 +180,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopSched,
PipelineVersion::v2>;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit = using Block2ETileMapKSplit =
......
...@@ -55,6 +55,7 @@ template <index_t BlockSize, ...@@ -55,6 +55,7 @@ template <index_t BlockSize,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec, tensor_operation::device::GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t K0PerBlock, index_t K0PerBlock,
...@@ -82,7 +83,9 @@ template <index_t BlockSize, ...@@ -82,7 +83,9 @@ template <index_t BlockSize,
index_t CShuffleMRepeatPerShuffle, index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL, index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock> typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -101,6 +104,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -101,6 +104,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
struct Argument : public ck::tensor_operation::device::BaseArgument struct Argument : public ck::tensor_operation::device::BaseArgument
{ {
const FloatAB* p_a_grid; const FloatAB* p_a_grid;
...@@ -497,6 +503,18 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -497,6 +503,18 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
} }
} }
const auto num_k_loop = karg.K0 / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
#if DEBUG_LOG
std::cout << "The number of k loops (" << num_k_loop
<< ") value is not supported by GridwiseGemm Pipeline."
<< " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
return true; return true;
} }
...@@ -509,9 +527,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -509,9 +527,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{ {
const bool has_main_k0_block_loop = K0 > K0PerBlock; const index_t num_loop = K0 / K0PerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
return has_main_k0_block_loop;
} }
template <typename CGridDesc> template <typename CGridDesc>
...@@ -750,20 +767,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -750,20 +767,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
// sanity check // sanity check
#if 1
auto blockwise_gemm = auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>{};
#else
auto blockwise_gemm = BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
...@@ -773,9 +778,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -773,9 +778,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
K1>{}; K1,
LoopSched>();
#endif
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...@@ -831,7 +835,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -831,7 +835,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
k0_block_data_begin += K0PerBlock; k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock)); } while(k0_block_data_begin < (karg.K0 - K0PerBlock));
} }
// tail // tail
...@@ -841,14 +845,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -841,14 +845,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
} }
#else #else
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVersion::v2, 1, LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
(K0PerBlock * K1)); (K0PerBlock * K1));
const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc, gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
a_b_k0_m_k1_block_desc, a_b_k0_m_k1_block_desc,
a_blockwise_copy, a_blockwise_copy,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment