"examples/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "7d101f83711703b27996a3c6fc64dd6cb101ec7d"
Unverified commit d2afb513, authored by Lei Wang and committed by GitHub
Browse files

[Refactor] Introduce GemmInst for different targets handling (#688)

* [Enhancement] Refactor GEMM operations for improved warp partitioning and target instruction handling

- Introduced a new `GetGemmInst` method to determine the appropriate GEMM instruction based on block size and target architecture.
- Updated `ComputeWarpPartition` to accept the GEMM instruction type, enhancing flexibility in warp partitioning logic.
- Added `TargetGetWarpSize` utility to streamline warp size retrieval based on target architecture.
- Refactored layout inference and lowering methods to utilize the new GEMM instruction handling, improving clarity and maintainability of the codebase.

* bug fix

* test fix

* lint fix
parent 73bf8346
...@@ -58,18 +58,35 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { ...@@ -58,18 +58,35 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
} }
} }
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target, Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
bool maybe_hopper_wgmma) const { int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
}
}
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
GemmInst gemm_inst,
Target target) const {
int num_warps = block_size / TargetGetWarpSize(target);
int m_warp = 1, n_warp = 1; int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp constexpr int kMPerWarp = 16; // Rows processed by a single warp
constexpr int kNPerWarp = 8; // Columns processed by a single warp constexpr int kNPerWarp = 8; // Columns processed by a single warp
bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
(this->M >= 64) && (num_warps % 4 == 0);
ICHECK(this->M % kMPerWarp == 0) ICHECK(this->M % kMPerWarp == 0)
<< "M must be divisible by " << kMPerWarp << ", but got " << this->M; << "M must be divisible by " << kMPerWarp << ", but got " << this->M;
ICHECK(this->N % kNPerWarp == 0) ICHECK(this->N % kNPerWarp == 0)
<< "N must be divisible by " << kNPerWarp << ", but got " << this->N; << "N must be divisible by " << kNPerWarp << ", but got " << this->N;
if (allow_wgmma) { if (gemm_inst == GemmInst::kWGMMA) {
ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads."; ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";
constexpr int kGroup = 4; // Number of warps in a warp-group constexpr int kGroup = 4; // Number of warps in a warp-group
...@@ -268,16 +285,9 @@ bool Gemm::CheckWGMMA() const { ...@@ -268,16 +285,9 @@ bool Gemm::CheckWGMMA() const {
} }
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32;
if (TargetIsCDNA(T.target)) {
warp_size = 64;
}
auto block_size = *as_const_int(T.thread_bounds->extent); auto block_size = *as_const_int(T.thread_bounds->extent);
bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) && GemmInst gemm_inst = GetGemmInst(block_size, T.target);
(block_size / warp_size % 4 == 0) && CheckWGMMA(); auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
std::stringstream ss; std::stringstream ss;
std::string op_name = "tl::gemm_ss"; std::string op_name = "tl::gemm_ss";
...@@ -295,7 +305,7 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -295,7 +305,7 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
// for cdna gemm, we need to specify kPack // for cdna gemm, we need to specify kPack
ss << ", " << kPack; ss << ", " << kPack;
} else if (TargetIsHopper(T.target)) { } else if (TargetIsHopper(T.target)) {
ss << ", " << (maybe_wgmma ? "true" : "false"); ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
} }
if (wg_wait != 0) { if (wg_wait != 0) {
ss << ", " << wg_wait; ss << ", " << wg_wait;
...@@ -321,10 +331,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -321,10 +331,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(C.scope() == "local.fragment"); ICHECK(C.scope() == "local.fragment");
auto thread_range = T.thread_bounds; auto thread_range = T.thread_bounds;
auto block_size = *as_const_int(thread_range->extent); auto block_size = *as_const_int(thread_range->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);
if (TargetIsVolta(T.target)) { if (TargetIsVolta(T.target)) {
const int warp_size = 32;
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target);
auto fragment = auto fragment =
makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits()); makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range)); results.Set(C, fragment->BindThreadRange(thread_range));
...@@ -347,9 +357,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -347,9 +357,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
*as_const_int(B->shape[dim_B - 1]), *as_const_int(B->shape[dim_B - 1]),
false, trans_B ? 2 : 1)); false, trans_B ? 2 : 1));
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) { } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
const int warp_size = 32;
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target);
auto fragment = auto fragment =
makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits()); makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range)); results.Set(C, fragment->BindThreadRange(thread_range));
...@@ -383,13 +390,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -383,13 +390,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(0); ICHECK(0);
} }
} else if (TargetIsHopper(T.target)) { } else if (TargetIsHopper(T.target)) {
const int warp_size = 32;
bool maybe_wgmma =
(this->M >= 64) && (block_size / warp_size % 4 == 0) && CheckWGMMA();
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
auto fragment = auto fragment =
maybe_wgmma gemm_inst == GemmInst::kWGMMA
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n, ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
C->dtype.bits()) C->dtype.bits())
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits()); : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
...@@ -401,7 +403,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -401,7 +403,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
const int64_t continuity = const int64_t continuity =
trans_A ? 4 * mat_continuous / warp_m : mat_continuous; trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
auto ABLayout = auto ABLayout =
maybe_wgmma gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
A->dtype.bits(), trans_A ? 1 : 2) A->dtype.bits(), trans_A ? 1 : 2)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
...@@ -419,7 +421,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -419,7 +421,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
const int64_t continuity = const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n; trans_B ? mat_continuous : mat_continuous / warp_n;
auto ABLayout = auto ABLayout =
maybe_wgmma gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1) B->dtype.bits(), trans_B ? 2 : 1)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
...@@ -429,10 +431,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -429,10 +431,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(0) << "WGMMA only support B in shared."; ICHECK(0) << "WGMMA only support B in shared.";
} }
} else if (TargetIsCDNA(T.target)) { } else if (TargetIsCDNA(T.target)) {
const int warp_size = 64;
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target);
auto fragment = auto fragment =
makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits()); makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range)); results.Set(C, fragment->BindThreadRange(thread_range));
......
...@@ -27,9 +27,12 @@ public: ...@@ -27,9 +27,12 @@ public:
} policy; } policy;
private: private:
std::pair<int, int> // Target GEMM instruction
ComputeWarpPartition(int num_warps, Target target, enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA };
bool maybe_hopper_wgmma = true) const; GemmInst GetGemmInst(int block_size, Target target) const;
// Computes how the warps of a thread block are split across the M and N
// dimensions for the selected GEMM instruction.
//   block_size: number of threads in the block (not warps); the definition
//               divides it by TargetGetWarpSize(target) to get the warp count.
//   gemm_inst:  instruction kind, e.g. from GetGemmInst(block_size, target).
//   target:     compilation target, used to look up the warp size.
// Callers unpack the result as (warp_m, warp_n).
std::pair<int, int> ComputeWarpPartition(int block_size, GemmInst gemm_inst,
                                         Target target) const;
bool CheckWGMMA() const; bool CheckWGMMA() const;
Array<PrimExpr> call_args; Array<PrimExpr> call_args;
......
...@@ -97,5 +97,12 @@ bool TargetHasStmatrix(Target target) { ...@@ -97,5 +97,12 @@ bool TargetHasStmatrix(Target target) {
return arch >= 90; return arch >= 90;
} }
// Returns the warp size used for `target`: 64 when TargetIsCDNA(target)
// holds, and 32 for every other target.
int TargetGetWarpSize(Target target) {
  return TargetIsCDNA(target) ? 64 : 32;
}
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -24,6 +24,7 @@ bool TargetIsCDNA(Target target); ...@@ -24,6 +24,7 @@ bool TargetIsCDNA(Target target);
bool TargetHasAsyncCopy(Target target); bool TargetHasAsyncCopy(Target target);
bool TargetHasLdmatrix(Target target); bool TargetHasLdmatrix(Target target);
bool TargetHasStmatrix(Target target); bool TargetHasStmatrix(Target target);
int TargetGetWarpSize(Target target);
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment