Unverified commit d2afb513, authored by Lei Wang and committed by GitHub
Browse files

[Refactor] Introduce GemmInst for different targets handling (#688)

* [Enhancement] Refactor GEMM operations for improved warp partitioning and target instruction handling

- Introduced a new `GetGemmInst` method to determine the appropriate GEMM instruction based on block size and target architecture.
- Updated `ComputeWarpPartition` to accept the GEMM instruction type, enhancing flexibility in warp partitioning logic.
- Added `TargetGetWarpSize` utility to streamline warp size retrieval based on target architecture.
- Refactored layout inference and lowering methods to utilize the new GEMM instruction handling, improving clarity and maintainability of the codebase.

* bug fix

* test fix

* lint fix
parent 73bf8346
......@@ -58,18 +58,35 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
}
}
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma) const {
// Selects the GEMM instruction family to use for this op on `target`.
//
// block_size: thread-block extent (callers pass T.thread_bounds->extent); it
//             is divided by TargetGetWarpSize(target) to get the warp count.
// target:     compilation target being lowered for.
//
// Returns kWGMMA when the target is Hopper, M >= 64, the warp count is a
// multiple of 4 and CheckWGMMA() passes; otherwise kMFMA for CDNA targets and
// kMMA for CUDA targets. Any other target trips the ICHECK below.
Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
  const int warp_size = TargetGetWarpSize(target);
  const int num_warps = block_size / warp_size;
  // Evaluation order matters: CheckWGMMA() is only consulted once the cheap
  // target/shape/warp-count conditions already hold.
  const bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                           (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  }
  if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  }
  if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  }
  ICHECK(0) << "Unsupported target for gemm: " << target->str();
  // NOTE(review): ICHECK(0) is expected not to return, but the compiler cannot
  // see that; this explicit return keeps every path of a non-void function
  // well-formed (no -Wreturn-type warning, no fall-off-the-end UB).
  return GemmInst::kMMA;
}
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
GemmInst gemm_inst,
Target target) const {
int num_warps = block_size / TargetGetWarpSize(target);
int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp
constexpr int kNPerWarp = 8; // Columns processed by a single warp
bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
(this->M >= 64) && (num_warps % 4 == 0);
ICHECK(this->M % kMPerWarp == 0)
<< "M must be divisible by " << kMPerWarp << ", but got " << this->M;
ICHECK(this->N % kNPerWarp == 0)
<< "N must be divisible by " << kNPerWarp << ", but got " << this->N;
if (allow_wgmma) {
if (gemm_inst == GemmInst::kWGMMA) {
ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";
constexpr int kGroup = 4; // Number of warps in a warp-group
......@@ -268,16 +285,9 @@ bool Gemm::CheckWGMMA() const {
}
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32;
if (TargetIsCDNA(T.target)) {
warp_size = 64;
}
auto block_size = *as_const_int(T.thread_bounds->extent);
bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
(block_size / warp_size % 4 == 0) && CheckWGMMA();
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);
std::stringstream ss;
std::string op_name = "tl::gemm_ss";
......@@ -295,7 +305,7 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
// for cdna gemm, we need to specify kPack
ss << ", " << kPack;
} else if (TargetIsHopper(T.target)) {
ss << ", " << (maybe_wgmma ? "true" : "false");
ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
}
if (wg_wait != 0) {
ss << ", " << wg_wait;
......@@ -321,10 +331,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(C.scope() == "local.fragment");
auto thread_range = T.thread_bounds;
auto block_size = *as_const_int(thread_range->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);
if (TargetIsVolta(T.target)) {
const int warp_size = 32;
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target);
auto fragment =
makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
......@@ -347,9 +357,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
*as_const_int(B->shape[dim_B - 1]),
false, trans_B ? 2 : 1));
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
const int warp_size = 32;
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target);
auto fragment =
makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
......@@ -383,13 +390,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(0);
}
} else if (TargetIsHopper(T.target)) {
const int warp_size = 32;
bool maybe_wgmma =
(this->M >= 64) && (block_size / warp_size % 4 == 0) && CheckWGMMA();
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
auto fragment =
maybe_wgmma
gemm_inst == GemmInst::kWGMMA
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
C->dtype.bits())
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
......@@ -401,7 +403,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
const int64_t continuity =
trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
auto ABLayout =
maybe_wgmma
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
A->dtype.bits(), trans_A ? 1 : 2)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
......@@ -419,7 +421,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n;
auto ABLayout =
maybe_wgmma
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
......@@ -429,10 +431,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(0) << "WGMMA only support B in shared.";
}
} else if (TargetIsCDNA(T.target)) {
const int warp_size = 64;
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target);
auto fragment =
makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
......
......@@ -27,9 +27,12 @@ public:
} policy;
private:
std::pair<int, int>
ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma = true) const;
// GEMM instruction family chosen for a target. As selected by GetGemmInst:
//   kWGMMA  - Hopper targets when M >= 64, the warp count is a multiple of 4
//             and CheckWGMMA() passes
//   kMFMA   - CDNA targets
//   kMMA    - remaining CUDA targets
//   kUTCMMA - not returned by GetGemmInst in the code visible here
enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA };
// Picks the instruction family for a thread block of `block_size` threads on
// `target`; unsupported targets fail an ICHECK in the definition.
GemmInst GetGemmInst(int block_size, Target target) const;
// Computes how the block's warps are split across the M and N dimensions;
// callers bind the result as [warp_m, warp_n].
//
// The first argument is the thread-block size, not a warp count: the
// definition divides it by TargetGetWarpSize(target) to obtain the number of
// warps, and callers pass the thread-bounds extent directly.
std::pair<int, int> ComputeWarpPartition(int block_size, GemmInst gemm_inst,
                                         Target target) const;
bool CheckWGMMA() const;
Array<PrimExpr> call_args;
......
......@@ -97,5 +97,12 @@ bool TargetHasStmatrix(Target target) {
return arch >= 90;
}
// Warp size assumed for `target`: 64 when TargetIsCDNA(target) holds,
// 32 for every other target.
int TargetGetWarpSize(Target target) {
  return TargetIsCDNA(target) ? 64 : 32;
}
} // namespace tl
} // namespace tvm
......@@ -24,6 +24,7 @@ bool TargetIsCDNA(Target target);
bool TargetHasAsyncCopy(Target target);
bool TargetHasLdmatrix(Target target);
bool TargetHasStmatrix(Target target);
int TargetGetWarpSize(Target target);
} // namespace tl
} // namespace tvm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment