/*!
 * \file tl/op/gemm.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm.h"

#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "tcgen5_meta.h"
#include "utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Construct a Gemm operator from serialized TL arguments.
 *
 * This constructor deserializes operator parameters from `args`, normalizing
 * the A/B/C arguments into buffer regions and populating an internal GemmNode
 * with:
 * - buffer regions and Buffer objects for A, B, and C,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
 *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      (optional) kPack (Int), (optional) wg_wait (Int),
 *      (optional) mbar (BufferLoad), C coordinates (PrimExpr, PrimExpr)]
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
// NormalizeToBufferRegion moved to src/op/utils.{h,cc}

// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}

Gemm::Gemm(Array<PrimExpr> args) {
  ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();

  node->aRegion_ = NormalizeToBufferRegion(args[0]);
  node->bRegion_ = NormalizeToBufferRegion(args[1]);
  node->cRegion_ = NormalizeToBufferRegion(args[2]);

  node->a_ = node->aRegion_->buffer;
  node->b_ = node->bRegion_->buffer;
  node->c_ = node->cRegion_->buffer;
  node->transA_ = args[3].as<Bool>().value();
  node->transB_ = args[4].as<Bool>().value();
  node->m_ = args[5].as<IntImm>().value()->value;
  node->n_ = args[6].as<IntImm>().value()->value;
  node->k_ = args[7].as<IntImm>().value()->value;
  node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clearAccum_ = args[9].as<PrimExpr>().value();
  node->strideA_ = args[10].as<IntImm>().value()->value;
  node->strideB_ = args[11].as<IntImm>().value()->value;
  node->offsetA_ = args[12].as<IntImm>().value()->value;
  node->offsetB_ = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack_ = args[14].as<IntImm>().value()->value;
    if (node->kPack_ != 1 && node->kPack_ != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wgWait_ = args[15].as<IntImm>().value()->value;
  }
  if (args.size() > 16) {
    if (const auto *load = args[16].as<BufferLoadNode>()) {
      node->mbarRegion_ =
          NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
      node->mbar_ = node->mbarRegion_->buffer;
    } else {
      node->mbar_ = std::nullopt;
    }
  }
  node->cCoords_ = Array<PrimExpr>(
      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
  data_ = std::move(node);
}

/**
 * @brief Create a copy of this GemmNode as a TileOperator.
 *
 * Constructs a new GemmNode by copying the current node state and returns it
 * wrapped in a Gemm TileOperator.
 *
 * @return TileOperator A Gemm operator that owns a copy of this node.
 */
TileOperator GemmNode::Clone() const {
  auto op = tvm::ffi::make_object<GemmNode>(*this);
  return Gemm(op);
}

bool GemmNode::allowTcgen5Mma(Target target) const {
  return TargetIsSm100(target) &&
         ((a_.scope() == "shared.dyn" || a_.scope() == "shared" ||
           a_.scope() == "shared.tmem") &&
          (b_.scope() == "shared.dyn" || b_.scope() == "shared") &&
          c_.scope() == "shared.tmem") &&
         GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first;
}

bool GemmNode::allowWgmma(int block_size, Target target) const {
  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
         TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
         checkWgmma();
}
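
// For example, with a warp size of 32 a block of 256 threads gives
// num_warps = 8, which satisfies the num_warps % 4 == 0 requirement above,
// whereas a block of 96 threads (3 warps) rules out WGMMA regardless of the
// other conditions.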

GemmInst GemmNode::getGemmInst(int block_size, Target target) const {
  if (allowTcgen5Mma(target)) {
    return GemmInst::kTCGEN5MMA;
  } else if (allowWgmma(block_size, target)) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target;
    return GemmInst::kMMA;
  }
}
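
// Illustrative dispatch: an sm_100 kernel whose operands satisfy
// allowTcgen5Mma (A/B in shared memory, C in "shared.tmem", a supported tile
// shape) selects kTCGEN5MMA; an sm_90 (Hopper) kernel passing allowWgmma
// selects kWGMMA; CDNA targets select kMFMA; any other CUDA target falls back
// to kMMA.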

std::pair<int, int> GemmWarpPolicyNode::computeWarpPartition(
    int M, int N, int block_size, Target target, GemmInst gemm_inst) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning
  }

  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  int kNPerWarp = 8;            // Columns processed by a single warp
  if (TargetIsVolta(target)) {
    kNPerWarp = 16;
  }
  ICHECK(M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << M;
  ICHECK(N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << N;

  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Remaining warps go to the N dimension

    if (this->isFullRow()) {
      // Try to put as many warp-groups as possible on the M dimension
      // (decreasing in multiples of 4, requiring M % (cand * kMPerWarp) == 0)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->isFullCol()) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                 // Initially assume all on N
      if (N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->isSquare()) {
      // Exhaustive search, but m must be multiple of 4
      int max_m = M / kMPerWarp;
      int max_n = N / kNPerWarp;

      float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
        << ", n_warp: " << n_warp << ", num_warps: " << num_warps;

    // Store the computed values in the object's member variables
    this->m_warp = m_warp;
    this->n_warp = n_warp;

    return {m_warp, n_warp};
  }

  if (this->isFullRow()) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->isFullCol()) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->isSquare()) {
    // First calculate the maximum possible warps for each dimension
    int max_m_warps =
        M / kMPerWarp; // Each warp needs at least 16 elements in M

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (N > 0) {
      ideal_ratio = static_cast<float>(M) / N;
    }

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();
    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
      // m_per_warp and n_per_warp must be at least 1
      if (m_per_warp < 1 || n_per_warp < 1)
        continue;
      // m * n must equal num_warps
      if (m * n != num_warps)
        continue;

      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  ICHECK(m_warp * n_warp == num_warps)
      << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
      << ", n_warp: " << n_warp << ", num_warps: " << num_warps;

  // Store the computed values in the object's member variables
  this->m_warp = m_warp;
  this->n_warp = n_warp;

  return {m_warp, n_warp};
}
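
// Worked example (illustrative, assuming a warp size of 32): for M = 128,
// N = 128 and block_size = 256 on Hopper, num_warps = 8 and the WGMMA branch
// is taken. Under FullRow the loop tries cand = 8 first and 128 % (8 * 16)
// == 0, so the partition is {m_warp = 8, n_warp = 1}. Under Square the search
// compares m = 4 (score |2/8 - 1| = 0.75) with m = 8 (score |1/16 - 1| ~ 0.94)
// and settles on {m_warp = 4, n_warp = 2}.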

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and
 *     K % 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 *     and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
bool GemmNode::checkWgmma() const {
  if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
    return false;
  }

  if (c_->dtype == DataType::Float(16)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
    else if (a_->dtype.is_float8() && b_->dtype.is_float8())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else
      return false;
  } else if (c_->dtype == DataType::Float(32)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::BFloat(16) &&
             b_->dtype == DataType::BFloat(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::Float(32) &&
             b_->dtype == DataType::Float(32))
      return (!transA_) && transB_ && k_ % 8 == 0;
    else if (a_->dtype.is_float8() && b_->dtype.is_float8())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else
      return false;
  } else if (c_->dtype == DataType::Int(32)) {
    if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}
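
// Example (illustrative): C = float32 with A = B = float16 and K = 64 passes
// regardless of the transpose flags, whereas A = B = float32 additionally
// requires (!transA_ && transB_) and K % 8 == 0, so K = 60 or a transposed A
// would be rejected.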

/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it matches the pattern
 * "sm_<num>", returns <num> as an int. If the attribute is present but does not
 * match that pattern, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked via
 * ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<tvm::ffi::String>("arch");
  ICHECK(s.has_value());
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
  } else {
    arch_int = 0;
  }
  return arch_int;
}
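
// Example: "sm_80" -> 80 and "sm_90a" -> 90 (std::stoi stops at the first
// non-digit character); an arch string without the "sm_" prefix (e.g. a CDNA
// "gfx942") yields 0.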

/**
 * @brief Lower the GEMM operator to a TL TIR call expression.
 *
 * Constructs a tl::gemm call string parameterized by M, N, K, warp partition,
 * transpose flags, accumulation clearing, target-specific stride/offset/kPack
 * and optional warp-group wait (wg_wait) value, then returns an
 * Evaluate(call) node
 * invoking tl::tl_gemm with the composed string and the A/B/C buffer handles.
 *
 * @param T Contains lowering context including thread bounds and target.
 * @param analyzer Optional arithmetic analyzer used by lowering (may be
 * nullptr).
 * @return Stmt A TIR statement representing the evaluated TL GEMM call.
 */
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = getGemmInst(block_size, T.target);
  auto [warp_m, warp_n] =
      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);

  // Build access pointers from regions locally
  PrimExpr Aptr =
      MakeAccessPtrFromRegion(aRegion_, /*r*/ 1, /*require_2d*/ true);
  PrimExpr Bptr =
      MakeAccessPtrFromRegion(bRegion_, /*r*/ 1, /*require_2d*/ true);
  PrimExpr Cptr =
      MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3, /*require_2d*/ true);

  std::stringstream ss;
  std::string op_name;

  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
    ICHECK(can_use_tcgen5mma);
    ICHECK(b_.scope() == "shared.dyn" || b_.scope() == "shared");
    ICHECK(c_.scope() == "shared.tmem");
    ICHECK(mbar_.has_value()) << "mbar must be provided for TCGEN5MMA";
    if (a_.scope() == "shared.tmem") {
      op_name = "tl::tcgen5mma_gemm_ts";
    } else if (a_.scope() == "shared.dyn" || a_.scope() == "shared") {
      op_name = "tl::tcgen5mma_gemm_ss";
    } else {
      ICHECK(0)
          << "Unsupported A scope for TCGEN5MMA: "
          << a_.scope(); // If this is triggered, it means Tilelang has bugs.
    }
    ICHECK(wgWait_ == -1)
        << "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
           "use wg_wait = -1 and manually synchronize with mbarrier.";

    std::string accum_dtype = "";
    if (c_->dtype.is_float()) {
      if (c_->dtype.bits() == 32) {
        accum_dtype = "float";
      }
    }
    ICHECK(!accum_dtype.empty())
        << "Unsupported C dtype for TCGEN5MMA: " << c_->dtype;
    ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
    ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", ";
    ss << transA_ << ", " << transB_ << ", ";
    ss << accum_dtype;
    ss << ">";

    auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
    Array<PrimExpr> new_args;
    auto mbarPtr =
        MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true);
    new_args.push_back(StringImm(ss.str()));
    new_args.push_back(Aptr);
    new_args.push_back(Bptr);
    new_args.push_back(BufferLoad(C_buffer, cCoords_));
    new_args.push_back(mbarPtr);
    new_args.push_back(clearAccum_);
    auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);

    // Since TCGEN5MMA atoms provided by CUTLASS always have an internal
    // `elect_one_sync()`, we check if we are calling it using full warps
    constexpr int warp_size = 32;
    ICHECK(
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, warp_size), 0) &&
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->extent, warp_size),
                                0))
        << "TCGEN5MMA requires thread bounds to be multiples of warp size (32) "
           "and aligned to warps.";
    if (analyzer->CanProveEqual(T.thread_bounds->extent, warp_size)) {
      // If the thread bounds is exactly one warp, we can use the original call
      return Evaluate(new_call);
    } else {
      // Add an if-else clause
      auto tcgen5mma_call =
          IfThenElse(EQ(FloorDiv(T.thread_var, warp_size),
                        FloorDiv(T.thread_bounds->min, warp_size)),
                     Evaluate(new_call));
      return tcgen5mma_call;
    }
  }

  if (a_.scope() == "local.fragment") {
    ICHECK(b_.scope() != "local.fragment");
    ICHECK(!transA_)
        << "gemm_rs requires the A operand to be in non-transposed layout.";
    op_name = "tl::gemm_rs";
  } else if (b_.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  } else {
    op_name = "tl::gemm_ss";
  }
  ICHECK(c_.scope() == "local.fragment");

  ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << transA_ << ", " << transB_;
  auto clear_accum_bool = clearAccum_.as<Bool>();
  ICHECK(clear_accum_bool.has_value())
      << "clear_accum must be a constant Bool type, got " << clearAccum_;
  ss << ", " << bool(clear_accum_bool.value());
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << strideA_ << ", " << strideB_;
    ss << ", " << offsetA_ << ", " << offsetB_;
  }
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack_;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }

  // Emit wg_wait if necessary
  if (TargetIsHopper(T.target)) {
    if (wgWait_ != 0) {
      ss << ", " << wgWait_;
    }
  } else if (TargetIsSm100(T.target)) {
    // NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
    // but all threads need to wait, so we emit another statement for cases
    // where wg_wait == 0.
    ICHECK(wgWait_ == 0 || wgWait_ == -1)
        << "wg_wait must be 0 or -1 for Sm100";
  } else {
    ICHECK(wgWait_ == 0)
        << "wg_wait must be 0 for non-Hopper and non-Sm100 targets";
  }
  ss << ">";
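
  // For reference, the composed template-argument list is
  //   <M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum
  //    [, stride_A, stride_B, offset_A, offset_B]  (CUDA arch >= 75)
  //    [, kPack]                                   (CDNA)
  //    [, use_wgmma[, wg_wait]]>                   (Hopper)
  // e.g. an illustrative Ampere instance could read
  //   "tl::gemm_ss<128, 128, 32, 2, 2, 0, 1, 0, ...>"
  // with the stride/offset values appended.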

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}

/**
 * @brief Infer and bind target-specific memory/layout mappings for A, B, and C.
 *
 * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
 * operator according to the target architecture, thread bounds, warp
 * partitioning, data types, and transpose flags, then binds fragment layouts
 * to the thread range when required.
 *
 * Preconditions:
 * - C.scope() == "local.fragment" (or "shared.tmem" on the TCGEN5MMA path)
 *
 * Side effects:
 * - Marks layout inference as completed (sets completed_ = true).
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
 * @param T Input layout-inference context (provides thread bounds and target).
 * @return LayoutMap mapping A, B, and C to their inferred layouts.
 */
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = getGemmInst(block_size, T.target);
  auto [warp_m, warp_n] =
      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
  if (TargetIsVolta(T.target)) {
    ICHECK(c_.scope() == "local.fragment")
        << "Volta gemm only supports C in local.fragment scope, got "
        << c_.scope();
    auto fragment = makeGemmVoltaFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
                                           c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));
    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      results.Set(a_, makeGemmVoltaABLayout(*as_const_int(a_->shape[dim_A - 2]),
                                            *as_const_int(a_->shape[dim_A - 1]),
                                            true, !transA_));
    } else if (a_.scope() == "local.fragment") {
      ICHECK(transA_ == false);
      auto fragment =
          makeGemmVoltaFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(b_.scope() == "shared" || b_.scope() == "shared.dyn");
    int dim_B = b_->shape.size();
    results.Set(b_, makeGemmVoltaABLayout(*as_const_int(b_->shape[dim_B - 2]),
                                          *as_const_int(b_->shape[dim_B - 1]),
                                          false, transB_));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target) ||
             (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
    ICHECK(c_.scope() == "local.fragment")
        << "MMA only supports C in local.fragment scope, got " << c_.scope();

    auto fragment =
        makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));

    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
      results.Set(a_,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   a_->dtype.bits(), !transA_));
    } else if (a_.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                        a_->dtype.bits(), transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
      results.Set(b_,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   b_->dtype.bits(), transB_));
    } else if (b_.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    ICHECK(c_.scope() == "local.fragment")
        << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ")
        << "only supports C in local.fragment scope, got " << c_.scope();
    auto fragment = gemm_inst == GemmInst::kWGMMA
                        ? makeGemmFragmentCHopper(m_, n_, m_ / warp_m,
                                                  n_ / warp_n, c_->dtype.bits())
                        : makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
                                            c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));
    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
      const int64_t continuity =
          transA_ ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       a_->dtype.bits(), !transA_)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 a_->dtype.bits(), !transA_);
      results.Set(a_, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                        a_->dtype.bits(), transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    }
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
      const int64_t continuity =
          transB_ ? mat_continuous : mat_continuous / warp_n;

      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       b_->dtype.bits(), transB_)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 b_->dtype.bits(), transB_);
      results.Set(b_, ABLayout);
    } else {
      auto fragment =
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
    }
  } else if (gemm_inst == GemmInst::kTCGEN5MMA) {
    ICHECK(c_.scope() == "shared.tmem")
        << "TCGEN5MMA only supports C in shared.tmem scope, got " << c_.scope();
    ICHECK(a_.scope() == "shared.dyn" || a_.scope() == "shared")
        << "Current TCGEN5MMA only supports A in shared or shared.dyn scope";
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
    ICHECK(can_use_tcgen5mma);
    {
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
      results.Set(a_, makeGemmABLayoutSm100(mat_stride, mat_continuous,
                                            mat_continuous, a_->dtype.bits(),
                                            transA_ ? 1 : 2));
    }
    {
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
      const int64_t continuity = mat_continuous;
      results.Set(b_,
                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
                                        b_->dtype.bits(), transB_ ? 2 : 1));
    }
    {
      Layout res;
      IterVar i = make_itervar("i", m_);
      IterVar j = make_itervar("j", n_);
      ICHECK(m_ % meta.atom_m == 0);
      PrimExpr atom_idx = FloorDiv(i, meta.atom_m) +
                          FloorDiv(j, meta.atom_n) * (m_ / meta.atom_m);
      PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i"
      PrimExpr aj = FloorMod(j, meta.atom_n);
      if (meta.atom_m == 128) {
        // Layout D
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d)
        res = Layout(Array{i, j}, {ai, aj + atom_idx * meta.atom_n});
      } else if (meta.atom_m == 64) {
        // Layout E
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e)
        // since the .ws variant is used. For why the .ws variant is used
        // here, please refer to gemm_sm100.h
        res = Layout(Array{i, j}, {FloorDiv(ai, 32) * 32 + FloorMod(ai, 32) +
                                       FloorDiv(aj, meta.atom_n / 2) * 64,
                                   FloorMod(aj, meta.atom_n / 2) +
                                       atom_idx * (meta.atom_n / 2)});
      } else if (meta.atom_m == 32) {
        // Layout G
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-g)
        res = Layout(
            Array{i, j},
            {FloorMod(ai, 32) + FloorDiv(aj, meta.atom_n / 4) * 32,
             FloorMod(aj, meta.atom_n / 4) + atom_idx * (meta.atom_n / 4)});
      } else {
        ICHECK(0);
      }
      results.Set(c_, res);
    }
  } else if (TargetIsCDNA(T.target)) {
    ICHECK(c_.scope() == "local.fragment")
        << "CDNA gemm (MFMA) only supports C in local.fragment scope, got "
        << c_.scope();
    auto fragment = makeGemmFragmentCCDNA(m_, n_, m_ / warp_m, n_ / warp_n,
                                          c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));

    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(a_->shape[dim_A - 2]),
          *as_const_int(a_->shape[dim_A - 1]), a_->dtype.bits(), kPack_);
      results.Set(a_, shared_layout);
    } else if (a_.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentACDNA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                a_->dtype.bits(), kPack_, transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(b_->shape[dim_B - 2]),
          *as_const_int(b_->shape[dim_B - 1]), b_->dtype.bits(), kPack_);

      results.Set(b_, shared_layout);
    } else if (b_.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

TIR_REGISTER_TL_TILE_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tl.GemmWarpPolicy")
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");

TVM_FFI_STATIC_INIT_BLOCK() {
  GemmNode::RegisterReflection();
  GemmWarpPolicyNode::RegisterReflection();
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
                        [](GemmWarpPolicy policy, int M, int N, int block_size,
                           Target target, GemmInst gemm_inst) {
                          policy->computeWarpPartition(M, N, block_size, target,
                                                       gemm_inst);
                        });
}

} // namespace tl
} // namespace tvm