"examples/vscode:/vscode.git/clone" did not exist on "40555b3a5e5ef8faefed69cb599fad73afdb9574"
Commit 0c15de6a authored by Adam Osewski, committed by Sam Wu
Browse files

CK Tile - small fix to hotloop scheduler & KPack value. (#1867)

* Use SmemPack in HotLoop scheduler

* Additional debug print information

* Change KPack value.

Hardcode for now, as without AK1/BK1 there's no good way to determine
its value.

* Fix HotLoopScheduler MFMA instr parameters.
parent ab5d0278
...@@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
{ {
arg.Print(); arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
} }
if(!GridwiseGemm::CheckValidity(arg)) if(!GridwiseGemm::CheckValidity(arg))
...@@ -733,7 +734,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -733,7 +734,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
<< "BlkGemmPipelineVersion: " << "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: " << "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "Kpack: "
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride;
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -90,14 +90,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst ...@@ -90,14 +90,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
KPerXDL); KPerXDL);
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
"%d, %d\n C MFMA inst: %d\n", "%d, %d\n C MFMA inst: %d\n"
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
"%d/ %d\n",
A_Buffer_Load_Inst_Num, A_Buffer_Load_Inst_Num,
B_Buffer_Load_Inst_Num, B_Buffer_Load_Inst_Num,
A_LDS_Write_Inst_Num, A_LDS_Write_Inst_Num,
B_LDS_Write_Inst_Num, B_LDS_Write_Inst_Num,
A_LDS_Read_Inst_Num, A_LDS_Read_Inst_Num,
B_LDS_Read_Inst_Num, B_LDS_Read_Inst_Num,
C_MFMA_Inst_Num); C_MFMA_Inst_Num,
A_LDS_Read_Width,
B_LDS_Read_Width,
ALDSWriteWidth,
BLDSWriteWidth,
ABufferLoadWidth,
BBufferLoadWidth);
} }
}; };
......
...@@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr ...@@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr
// TODO: Should we have two policies? Interwave & Intrawave ?? // TODO: Should we have two policies? Interwave & Intrawave ??
static constexpr index_t InterWaveSchedulingMacClusters = 1; static constexpr index_t InterWaveSchedulingMacClusters = 1;
static constexpr index_t KPack = WarpGemm::kKPerThread; // should be at least equal to: WarpGemm::Impl::kABKPerLane
// and the question is how to assess upper limit or exact value?
// TODO: Should we introduce AK1/BK1 parameters ?
static constexpr index_t KPack = 8;
static constexpr index_t KPerThread = KIterPerWarp * KPack; static constexpr index_t KPerThread = KIterPerWarp * KPack;
static constexpr index_t KRepeat = KPerThread / KPack; static constexpr index_t KRepeat = KPerThread / KPack;
}; };
......
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
#pragma once #pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
...@@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
/// @brief Build a host-side, human-readable summary of the per-block instruction
///        counts and widths that the hot-loop scheduling logic is derived from.
///
/// The quantities are recomputed here from the same compile-time inputs that
/// HotLoopScheduler() uses (warp GEMM tile sizes, block warp layout, LDS pack
/// sizes and global-load vector sizes), so the output can be logged for debugging.
///
/// @return Multi-line string, one "label: value[, value]" entry per line.
CK_TILE_HOST static std::string Print()
{
    // Tile dimensions of the warp-level GEMM instruction.
    constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
    constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
    constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;

    // Wave size is hard-coded to 64 here, exactly as in HotLoopScheduler().
    constexpr index_t WaveSize = 64;
    constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
    constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

    // Below should be equal to AK1|BK1
    constexpr index_t A_LDS_Read_Width  = Policy::template GetSmemPackA<Problem>();
    constexpr index_t B_LDS_Read_Width  = Policy::template GetSmemPackB<Problem>();
    constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
    constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();

    // Global (buffer) loads: tile elements divided by (threads * elements per load).
    constexpr index_t A_Buffer_Load_Inst_Num =
        MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
    constexpr index_t B_Buffer_Load_Inst_Num =
        NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());

    // LDS writes: tile elements divided by (threads * elements per write).
    constexpr index_t A_LDS_Write_Inst_Num =
        MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
    constexpr index_t B_LDS_Write_Inst_Num =
        NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);

    // LDS reads: the A tile (M x K) is scaled by the number of waves along N,
    // the B tile (N x K) by the number of waves along M.
    constexpr index_t A_LDS_Read_Inst_Num =
        WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
    // The B tile is NPerBlock x KPerBlock, so NPerBlock (not MPerBlock) is used,
    // matching the B write / buffer-load formulas above.
    // NOTE(review): HotLoopScheduler() in this file computes the same quantity
    // with MPerBlock; it should receive the same correction so that this report
    // and the scheduler agree for non-square block tiles.
    constexpr index_t B_LDS_Read_Inst_Num =
        WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);

    // Total block MACs split across the waves, divided by the MACs of one
    // warp-level GEMM instruction.
    constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
                                        (BlockSize / WaveSize) /
                                        (MPerXDL * NPerXDL * KPerXDL);

    auto str = std::stringstream{};

    // Read and write widths come from the same policy getters, so a single
    // pair of values is reported for both.
    str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
        << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
        << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
        << "\n"
        << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
        << "\n"
        << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
        << "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
        << "KPack: " << BlockGemm::Traits::KPack << "\n"
        << "PrefetchStages: " << PrefetchStages << "\n";
    return str.str();
}
template <GemmPipelineScheduler Scheduler> template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase struct PipelineImpl : public PipelineImplBase
{ {
...@@ -95,29 +148,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -95,29 +148,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
CK_TILE_DEVICE static constexpr auto HotLoopScheduler() CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{ {
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{}); constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{}); constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{}); constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64; constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = KPerXDL; // Below should be equal to AK1|BK1
constexpr index_t B_LDS_Read_Width = KPerXDL; constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_Buffer_Load_Inst_Num = constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num = constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); constexpr index_t A_LDS_Write_Inst_Num =
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num = constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL); WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num = constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL); WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (BlockSize / WaveSize) /
......
...@@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{ {
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment