/*!
 * \file tl/op/gemm.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm.h"

#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

struct TCGEN5MMAMeta {
  int atom_m, atom_n, atom_k;
};

// Return {is_success, meta}
static inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL                                                                   \
  return {                                                                     \
    false, TCGEN5MMAMeta { 0, 0, 0 }                                           \
  }
#define SUCCESS(atom_m, atom_n, atom_k)                                        \
  return {                                                                     \
    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k }                             \
  }
  std::vector<int> ws_valid_atom_ns = {256, 128, 64};
  if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
      (c_dtype.is_float() && c_dtype.bits() == 32)) {
    if (K % 16 != 0)
      FAIL;
    if (M % 128 == 0) {
      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 16);
      FAIL;
    } else if (M % 64 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(64, atom_n, 16);
      FAIL;
    } else if (M % 32 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(32, atom_n, 16);
      FAIL;
    } else {
      FAIL;
    }
  } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
             (c_dtype.is_float() && c_dtype.bits() == 32)) {
    if (K % 32 != 0)
      FAIL;
    if (M % 128 == 0) {
      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 32);
      FAIL;
    } else if (M % 64 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(64, atom_n, 32);
      FAIL;
    } else if (M % 32 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(32, atom_n, 32);
      FAIL;
    } else {
      FAIL;
    }
  }
  FAIL;
#undef FAIL
#undef SUCCESS
}
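
// Worked example (illustrative, derived from the rules above): for M = 256,
// N = 128, K = 64 with fp16 A/B and fp32 C, M % 128 == 0 holds, and the
// atom_n scan (256, 240, ..., 16) stops at the first divisor of N, yielding
// {true, {128, 128, 16}}. For M = 96 (a multiple of 32 but not of 64 or
// 128), only atom_n in {256, 128, 64} is tried, so N = 128 yields
// {true, {32, 128, 16}} while N = 96 fails.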

/**
 * @brief Construct a Gemm operator from serialized TL arguments and a buffer
 * map.
 *
 * This constructor deserializes operator parameters from `args` and resolves
 * buffer references via `vmap`, populating an internal GemmNode with:
 * - device pointers for A, B, C and their corresponding Buffer objects,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - kPack (must be 1 or 2) and wg_wait,
 * - the mbarrier pointer (resolved to a Buffer when it is an access_ptr
 *   call) and the accumulator coordinates C_coords.
 *
 * The populated GemmNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
 *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      kPack (Int), wg_wait (Int), mbarptr, C_coord_0 (PrimExpr),
 *      C_coord_1 (PrimExpr)]
 * @param vmap Mapping from access pointer vars to Buffer objects used to
 *   resolve the Buffer corresponding to each pointer argument.
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<GemmNode> node = make_object<GemmNode>();

  node->Aptr = args[0];
  node->Bptr = args[1];
  node->Cptr = args[2];
  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
  node->trans_A = args[3].as<Bool>().value();
  node->trans_B = args[4].as<Bool>().value();
  node->M = args[5].as<IntImm>().value()->value;
  node->N = args[6].as<IntImm>().value()->value;
  node->K = args[7].as<IntImm>().value()->value;
  node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clear_accum = args[9].as<PrimExpr>().value();
  node->stride_A = args[10].as<IntImm>().value()->value;
  node->stride_B = args[11].as<IntImm>().value()->value;
  node->offset_A = args[12].as<IntImm>().value()->value;
  node->offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack = args[14].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wg_wait = args[15].as<IntImm>().value()->value;
  }
  node->mbarptr = args[16];
  if (node->mbarptr.as<CallNode>()) {
    node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
  } else {
    node->mbar = std::nullopt;
  }
  node->C_coords = Array<PrimExpr>(
      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
  data_ = std::move(node);
}

/**
 * @brief Create a copy of this GemmNode as a TileOperator.
 *
 * Constructs a new GemmNode by copying the current node state and returns it
 * wrapped in a Gemm TileOperator.
 *
 * @return TileOperator A Gemm operator that owns a copy of this node.
 */
TileOperator GemmNode::Clone() const {
  auto op = make_object<GemmNode>(*this);
  return Gemm(op);
}

bool GemmNode::AllowTCGEN5MMA(Target target) const {
  return TargetIsSm100(target) &&
         ((A.scope() == "shared.dyn" || A.scope() == "shared" ||
           A.scope() == "shared.tmem") &&
          (B.scope() == "shared.dyn" || B.scope() == "shared") &&
          C.scope() == "shared.tmem") &&
         GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
}

bool GemmNode::AllowWGMMA(int block_size, Target target) const {
  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
         TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
         CheckWGMMA();
}

GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
  bool allow_tcgen5mma = AllowTCGEN5MMA(target);
  bool allow_wgmma = AllowWGMMA(block_size, target);
  if (allow_tcgen5mma) {
    return GemmInst::kTCGEN5MMA;
  } else if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
             TargetIsTuring(target) || TargetIsHopper(target) ||
             TargetIsSm100(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target->str();
  }
}
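
// Selection order, summarizing the checks above: TCGEN5MMA is preferred when
// the sm_100 scope/shape/dtype constraints hold; otherwise WGMMA on Hopper
// when a full warp-group is available and CheckWGMMA() passes; otherwise
// MFMA on CDNA; all remaining supported targets fall back to plain MMA.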

std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
    int M, int N, int block_size, Target target, GemmInst gemm_inst) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning
  }

  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp
  ICHECK(M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << M;
  ICHECK(N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << N;

  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Remaining warp-groups go on N

    if (this->isFullRow()) {
      // Try to put as many warp-groups as possible on the M dimension
      // (scanning multiples of 4 downward until M is evenly divided)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->isFullCol()) {
      // Prefer to put warps on the N dimension; if N cannot be evenly
      // divided, move excess warp-groups back to M
      int cand_n = n_warp;                 // Initially assume all on N
      if (N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->isSquare()) {
      // Exhaustive search, but m must be a multiple of 4
      int max_m = M / kMPerWarp;
      int max_n = N / kNPerWarp;

      float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
        << ", n_warp: " << n_warp << ", num_warps: " << num_warps;

    // Store the computed values in the object's member variables
    this->m_warp = m_warp;
    this->n_warp = n_warp;

    return {m_warp, n_warp};
  }

  if (this->isFullRow()) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->isFullCol()) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->isSquare()) {
    // First calculate the maximum possible warps for each dimension
    int max_m_warps =
        M / kMPerWarp; // Each warp needs at least 16 elements in M

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (N > 0) {
      ideal_ratio = static_cast<float>(M) / N;
    }

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();
    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
      // m_per_warp and n_per_warp must be at least 1
      if (m_per_warp < 1 || n_per_warp < 1)
        continue;
      // m * n must equal num_warps
      if (m * n != num_warps)
        continue;

      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  ICHECK(m_warp * n_warp == num_warps)
      << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
      << ", n_warp: " << n_warp << ", num_warps: " << num_warps;

  // Store the computed values in the object's member variables
  this->m_warp = m_warp;
  this->n_warp = n_warp;

  return {m_warp, n_warp};
}
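
// Worked example (illustrative): block_size = 256 gives num_warps = 8 with a
// warp size of 32. For M = N = 128 with kWGMMA and the FullRow policy, the
// scan tests cand = 8 first and M % (8 * 16) == 0 holds, so the partition is
// {m_warp, n_warp} = {8, 1}. The Square policy instead compares m = 4
// (score |2/8 - 1| = 0.75) against m = 8 (score |1/16 - 1| ~= 0.94) and
// settles on {4, 2}.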

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and
 *     K % 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 *     and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
bool GemmNode::CheckWGMMA() const {
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }

  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}
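
// Illustrative outcomes of CheckWGMMA(): fp16 A/B with fp32 C and K = 32
// passes regardless of the transpose flags (only K % 16 == 0 is required),
// fp32 A/B with fp32 C additionally requires !trans_A && trans_B and
// K % 8 == 0, and every fp8 and 8-bit integer combination requires
// !trans_A && trans_B and K % 32 == 0.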

/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it matches the pattern
 * "sm_<num>", returns <num> as an int. If the attribute is present but does not
 * match that pattern, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked via
 * ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
  } else {
    arch_int = 0;
  }
  return arch_int;
}
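
// Note: std::stoi stops at the first non-digit, so suffixed arch strings
// such as "sm_90a" parse as 90. Illustrative values: "sm_80" -> 80,
// "sm_90a" -> 90, "gfx942" -> 0 (no "sm_" prefix).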

/**
 * @brief Lower the GEMM operator to a TL TIR call expression.
 *
 * Constructs a tl::gemm call string parameterized by M, N, K, warp partition,
 * transpose flags, accumulation clearing, target-specific stride/offset/kPack
 * and the warp-group wait value (wg_wait), then returns an Evaluate(call)
 * node invoking tl::tl_gemm with the composed string and the A/B/C buffer
 * handles.
 *
 * @param T Contains lowering context including thread bounds and target.
 * @param analyzer Optional arithmetic analyzer used by lowering (may be
 * nullptr).
 * @return Stmt A TIR statement representing the evaluated TL GEMM call.
 */
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] =
      policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);

  std::stringstream ss;
  std::string op_name;

  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
    ICHECK(can_use_tcgen5mma);
    ICHECK(B.scope() == "shared.dyn" || B.scope() == "shared");
    ICHECK(C.scope() == "shared.tmem");
    ICHECK(mbar.has_value()) << "mbar must be provided for TCGEN5MMA";
    if (A.scope() == "shared.tmem") {
      op_name = "tl::tcgen5mma_gemm_ts";
    } else if (A.scope() == "shared.dyn" || A.scope() == "shared") {
      op_name = "tl::tcgen5mma_gemm_ss";
    } else {
      ICHECK(0)
          << "Unsupported A scope for TCGEN5MMA: "
          << A.scope(); // If this is triggered, it means Tilelang has bugs.
    }
    ICHECK(wg_wait == -1)
        << "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
           "use wg_wait = -1 and manually synchronize with mbarrier.";

    std::string accum_dtype = "";
    if (C->dtype.is_float()) {
      if (C->dtype.bits() == 32) {
        accum_dtype = "float";
      }
    }
    ICHECK(!accum_dtype.empty())
        << "Unsupported C dtype for TCGEN5MMA: " << C->dtype;
    ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
    ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", ";
    ss << trans_A << ", " << trans_B << ", ";
    ss << accum_dtype;
    ss << ">";

    auto C_buffer = T.buffer_remap.count(C) ? T.buffer_remap[C] : C;
    Array<PrimExpr> new_args;
    new_args.push_back(StringImm(ss.str()));
    new_args.push_back(Aptr);
    new_args.push_back(Bptr);
    new_args.push_back(BufferLoad(C_buffer, C_coords));
    new_args.push_back(mbarptr);
    new_args.push_back(clear_accum);
    auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);

    // Since TCGEN5MMA atoms provided by CUTLASS always have an internal
    // `elect_one_sync()`, we check if we are calling it using full warps
    constexpr int warp_size = 32;
    ICHECK(
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, warp_size), 0) &&
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->extent, warp_size),
                                0))
        << "TCGEN5MMA requires thread bounds to be multiples of warp size (32) "
           "and aligned to warps.";
    if (analyzer->CanProveEqual(T.thread_bounds->extent, warp_size)) {
      // If the thread bounds is exactly one warp, we can use the original call
      return Evaluate(new_call);
    } else {
      // Add an if-else clause
      auto tcgen5mma_call =
          IfThenElse(EQ(FloorDiv(T.thread_var, warp_size),
                        FloorDiv(T.thread_bounds->min, warp_size)),
                     Evaluate(new_call));
      return tcgen5mma_call;
    }
  }

  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  } else {
    op_name = "tl::gemm_ss";
  }
  ICHECK(C.scope() == "local.fragment");

  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  auto clear_accum_bool = clear_accum.as<Bool>();
  ICHECK(clear_accum_bool.has_value())
      << "clear_accum must be a constant Bool type, got " << clear_accum;
  ss << ", " << bool(clear_accum_bool.value());
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << stride_A << ", " << stride_B;
    ss << ", " << offset_A << ", " << offset_B;
  }
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }

  // Emit wg_wait if necessary
  if (TargetIsHopper(T.target)) {
    if (wg_wait != 0) {
      ss << ", " << wg_wait;
    }
  } else if (TargetIsSm100(T.target)) {
    // NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
    // but all threads need to wait, so we emit another statement for cases
    // where wg_wait == 0.
    ICHECK(wg_wait == 0 || wg_wait == -1)
        << "wg_wait must be 0 or -1 for Sm100";
  } else {
    ICHECK(wg_wait == 0)
        << "wg_wait must be 0 for non-Hopper and non-Sm100 targets";
  }
  ss << ">";

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}
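
// Illustrative emitted call (assuming an Ampere target, a 2x2 warp
// partition, trans_A = false, trans_B = true, and no accumulator clearing):
// the composed template string is
//   tl::gemm_ss<128, 128, 32, 2, 2, 0, 1, 0, 128, 128, 0, 0>
// i.e. <M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A,
// stride_B, offset_A, offset_B>, passed to tl::tl_gemm together with the
// A/B/C access pointers.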

/**
 * @brief Infer and bind target-specific memory/layout mappings for A, B, and C.
 *
 * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
 * operator according to the target architecture, thread bounds, warp
 * partitioning, data types, and transpose flags, then binds fragment layouts
 * to the thread range when required.
 *
 * Preconditions:
 * - C.scope() == "local.fragment" (or "shared.tmem" on the TCGEN5MMA path).
 *
 * Side effects:
 * - Marks layout inference as completed (sets completed_ = true).
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
 * @param T Input layout-inference context (provides thread bounds and target).
 * @return LayoutMap mapping A, B, and C to their inferred layouts.
 */
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] =
      policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
  if (TargetIsVolta(T.target)) {
    ICHECK(C.scope() == "local.fragment")
        << "Volta gemm only supports C in local.fragment scope, got "
        << C.scope();
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, !trans_A));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target) ||
             (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
    ICHECK(C.scope() == "local.fragment")
        << "MMA only supports C in local.fragment scope, got " << C.scope();

    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), !trans_A));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    ICHECK(C.scope() == "local.fragment")
        << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ")
        << "only supports C in local.fragment scope, got " << C.scope();
    auto fragment =
        gemm_inst == GemmInst::kWGMMA
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), !trans_A)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), !trans_A);
      results.Set(A, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B);
      results.Set(B, ABLayout);
    } else {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    }
  } else if (gemm_inst == GemmInst::kTCGEN5MMA) {
    ICHECK(C.scope() == "shared.tmem")
        << "TCGEN5MMA only supports C in shared.tmem scope, got " << C.scope();
    ICHECK(A.scope() == "shared.dyn" || A.scope() == "shared")
        << "Current TCGEN5MMA only supports A in shared.dyn scope";
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
    ICHECK(can_use_tcgen5mma);
    {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
                                           mat_continuous, A->dtype.bits(),
                                           trans_A ? 1 : 2));
    }
    {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity = mat_continuous;
      results.Set(B,
                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
                                        B->dtype.bits(), trans_B ? 2 : 1));
    }
    {
      Layout res;
      IterVar i = make_itervar("i", M);
      IterVar j = make_itervar("j", N);
      ICHECK(M % meta.atom_m == 0);
      PrimExpr atom_idx = FloorDiv(i, meta.atom_m) +
                          FloorDiv(j, meta.atom_n) * (M / meta.atom_m);
      PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i"
      PrimExpr aj = FloorMod(j, meta.atom_n);
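      // Worked example (illustrative): with M = N = 256 and
      // meta = {128, 256, 16}, element (i = 130, j = 5) lands in atom_idx = 1
      // with (ai, aj) = (2, 5); Layout D below then maps it to the physical
      // coordinate (2, 5 + 1 * 256) = (2, 261).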
      if (meta.atom_m == 128) {
        // Layout D
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d)
        res = Layout(Array{i, j}, {ai, aj + atom_idx * meta.atom_n});
      } else if (meta.atom_m == 64) {
        // Layout E
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e)
        // since the .ws variant is used. For why the .ws variant is used
        // here, please refer to gemm_sm100.h.
        res = Layout(Array{i, j}, {FloorDiv(ai, 32) * 32 + FloorMod(ai, 32) +
                                       FloorDiv(aj, meta.atom_n / 2) * 64,
                                   FloorMod(aj, meta.atom_n / 2) +
                                       atom_idx * (meta.atom_n / 2)});
      } else if (meta.atom_m == 32) {
        // Layout G
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-g)
        res = Layout(
            Array{i, j},
            {FloorMod(ai, 32) + FloorDiv(aj, meta.atom_n / 4) * 32,
             FloorMod(aj, meta.atom_n / 4) + atom_idx * (meta.atom_n / 4)});
      } else {
        ICHECK(0);
      }
      results.Set(C, res);
    }
  } else if (TargetIsCDNA(T.target)) {
    ICHECK(C.scope() == "local.fragment")
        << "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
        << C.scope();
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), kPack, trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tl.GemmWarpPolicy")
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");

TVM_FFI_STATIC_INIT_BLOCK({
  GemmNode::RegisterReflection();
  GemmWarpPolicyNode::RegisterReflection();
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
                        [](GemmWarpPolicy policy, int M, int N, int block_size,
                           Target target, GemmInst gemm_inst) {
                          policy->ComputeWarpPartition(M, N, block_size, target,
                                                       gemm_inst);
                          return;
                        });
});

} // namespace tl
} // namespace tvm