"docs/archive_en_US/Tutorial/Contributing.md" did not exist on "bbd441b3aff8d7f836f2be12c2589b4c4bba6e67"
gemm.cc 33.7 KB
Newer Older
1
2
/*!
 * \file tl/op/gemm.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm.h"

#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

struct TCGEN5MMAMeta {
  int atom_m, atom_n, atom_k;
};

// Return {is_success, meta}
static inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL                                                                   \
  return {                                                                     \
    false, TCGEN5MMAMeta { 0, 0, 0 }                                           \
  }
#define SUCCESS(atom_m, atom_n, atom_k)                                        \
  return {                                                                     \
    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k }                             \
  }
  std::vector<int> ws_valid_atom_ns = {256, 128, 64};
  if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
      (c_dtype.is_float() && c_dtype.bits() == 32)) {
    if (K % 16 != 0)
      FAIL;
    if (M % 128 == 0) {
      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 16);
      FAIL;
    } else if (M % 64 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(64, atom_n, 16);
      FAIL;
    } else if (M % 32 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(32, atom_n, 16);
      FAIL;
    } else {
      FAIL;
    }
  } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
             (c_dtype.is_float() && c_dtype.bits() == 32)) {
    if (K % 32 != 0)
      FAIL;
    if (M % 128 == 0) {
      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 32);
      FAIL;
    } else if (M % 64 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(64, atom_n, 32);
      FAIL;
    } else if (M % 32 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(32, atom_n, 32);
      FAIL;
    } else {
      FAIL;
    }
  }
  FAIL;
#undef FAIL
#undef SUCCESS
}
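
// Illustrative example (hedged, shapes assumed for demonstration): a bf16 x
// bf16 -> fp32 GEMM with M = 256, N = 128, K = 64 takes the M % 128 == 0
// branch and returns the largest atom_n in {256, 240, ..., 16} that divides
// N, i.e. {true, {128, 128, 16}}; an M of 48 (not a multiple of 32) fails.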

/**
 * @brief Construct a Gemm operator from serialized TL arguments and a buffer
 * map.
 *
 * This constructor deserializes operator parameters from `args` and resolves
 * buffer references via `vmap`, populating an internal GemmNode with:
 * - device pointers for A, B, C and their corresponding Buffer objects,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait,
 * - the mbarrier pointer (and its Buffer, when one is provided) and the C
 *   coordinate expressions.
 *
 * The populated GemmNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
 *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      (optional) kPack (Int), (optional) wg_wait (Int),
 *      mbarptr, C_coord_0, C_coord_1]
 * @param vmap Mapping from access pointer vars to Buffer objects used to
 *   resolve the Buffer corresponding to each pointer argument.
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<GemmNode> node = make_object<GemmNode>();

  node->Aptr = args[0];
  node->Bptr = args[1];
  node->Cptr = args[2];
  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
  node->trans_A = args[3].as<Bool>().value();
  node->trans_B = args[4].as<Bool>().value();
  node->M = args[5].as<IntImm>().value()->value;
  node->N = args[6].as<IntImm>().value()->value;
  node->K = args[7].as<IntImm>().value()->value;
  node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clear_accum = args[9].as<PrimExpr>().value();
  node->stride_A = args[10].as<IntImm>().value()->value;
  node->stride_B = args[11].as<IntImm>().value()->value;
  node->offset_A = args[12].as<IntImm>().value()->value;
  node->offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack = args[14].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wg_wait = args[15].as<IntImm>().value()->value;
  }
  node->mbarptr = args[16];
  if (node->mbarptr.as<CallNode>()) {
    node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
  } else {
    node->mbar = std::nullopt;
  }
  node->C_coords = Array<PrimExpr>(
      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
  data_ = std::move(node);
}
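
// Illustrative argument list (hedged; the concrete values and the mbar_ptr,
// i, j names are made up for demonstration): a frontend lowering such as
//   [A_ptr, B_ptr, C_ptr, /*trans_A=*/false, /*trans_B=*/true,
//    /*M=*/128, /*N=*/128, /*K=*/64, /*policy=*/0, /*clear_accum=*/false,
//    /*stride_A=*/64, /*stride_B=*/64, /*offset_A=*/0, /*offset_B=*/0,
//    /*kPack=*/1, /*wg_wait=*/0, mbar_ptr, i, j]
// would populate a 128x128x64 GEMM node with B transposed and C written at
// coordinates (i, j).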

/**
 * @brief Create a copy of this GemmNode as a TileOperator.
 *
 * Constructs a new GemmNode by copying the current node state and returns it
 * wrapped in a Gemm TileOperator.
 *
 * @return TileOperator A Gemm operator that owns a copy of this node.
 */
TileOperator GemmNode::Clone() const {
  auto op = make_object<GemmNode>(*this);
  return Gemm(op);
}

bool GemmNode::AllowTCGEN5MMA(Target target) const {
  return TargetIsSm100(target) &&
         ((A.scope() == "shared.dyn" || A.scope() == "shared" ||
           A.scope() == "shared.tmem") &&
          (B.scope() == "shared.dyn" || B.scope() == "shared") &&
          C.scope() == "shared.tmem") &&
         GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
}
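
// Note (hedged summary of the check above): TCGEN5MMA is considered only on
// sm_100-class targets with A staged in shared memory (or tmem), B in shared
// memory, and C accumulated in tmem; any other scope combination falls back
// to the WGMMA / MMA / MFMA paths chosen by GetGemmInst below.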

bool GemmNode::AllowWGMMA(int block_size, Target target) const {
  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;

  return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
         TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
         CheckWGMMA();
}

GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
  bool allow_tcgen5mma = AllowTCGEN5MMA(target);
  bool allow_wgmma = AllowWGMMA(block_size, target);
  if (allow_tcgen5mma) {
    return GemmInst::kTCGEN5MMA;
  } else if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target->str();
  }
}
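
// Illustrative dispatch (hedged): on an sm_90 target with block_size = 256
// (8 warps), M >= 64, and operands that pass CheckWGMMA(), this returns
// kWGMMA; an sm_80 target falls through to kMMA, and a CDNA target to kMFMA.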

std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
    int M, int N, int block_size, Target target, GemmInst gemm_inst) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning
  }

  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp
  ICHECK(M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << M;
  ICHECK(N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << N;

  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->isFullRow()) {
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->isFullCol()) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                 // Initially assume all on N
      if (N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->isSquare()) {
      // Exhaustive search, but m must be multiple of 4
      int max_m = M / kMPerWarp;
      int max_n = N / kNPerWarp;

      float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
        << ", n_warp: " << n_warp << ", num_warps: " << num_warps;

    // Store the computed values in the object's member variables
    this->m_warp = m_warp;
    this->n_warp = n_warp;

    return {m_warp, n_warp};
  }

  if (this->isFullRow()) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->isFullCol()) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->isSquare()) {
    // First calculate the maximum possible warps for each dimension
    int max_m_warps =
        M / kMPerWarp; // Each warp needs at least 16 elements in M

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (N > 0) {
      ideal_ratio = static_cast<float>(M) / N;
    }

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();
    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
      // m_per_warp and n_per_warp must be greater than 1
      if (m_per_warp < 1 || n_per_warp < 1)
        continue;
      // m * n must equal num_warps
      if (m * n != num_warps)
        continue;

      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  ICHECK(m_warp * n_warp == num_warps)
      << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
      << ", n_warp: " << n_warp << ", num_warps: " << num_warps;

  // Store the computed values in the object's member variables
  this->m_warp = m_warp;
  this->n_warp = n_warp;

  return {m_warp, n_warp};
}
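
// Worked example (hedged, shapes assumed for illustration): with the Square
// policy on a non-WGMMA target, M = N = 128 and block_size = 128 (4 warps,
// ideal_ratio = 1), the feasible partitions are (m_warp, n_warp) in
// {(1, 4), (2, 2), (4, 1)}; (2, 2) wins because its balance score
// |(128 / (2 * 16)) / (128 / (2 * 8)) - 1| = 0.5 is the smallest, giving
// 2 warps along M and 2 along N.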

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and
 *     K % 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 *     and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
bool GemmNode::CheckWGMMA() const {
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }

  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}
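
// Example (hedged): an fp16 x fp16 -> fp32 GEMM with K = 64 passes for any
// transpose combination, whereas the fp8 (e4m3/e5m2) combinations
// additionally require !trans_A && trans_B and K to be a multiple of 32.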

/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it matches the pattern
 * "sm_<num>", returns <num> as an int. If the attribute is present but does not
 * match that pattern, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked via
 * ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
  } else {
    arch_int = 0;
  }
  return arch_int;
}
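
// Example (hedged): "sm_80" yields 80 and "sm_90a" yields 90, since std::stoi
// stops at the first non-digit; an arch string without the "sm_" prefix
// (e.g. a CDNA "gfx942") yields 0.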

/**
 * @brief Lower the GEMM operator to a TL TIR call expression.
 *
 * Constructs a tl::gemm call string parameterized by M, N, K, warp partition,
 * transpose flags, accumulation clearing, target-specific stride/offset/kPack
 * and an optional warp-group wait value, then returns an Evaluate(call) node
 * invoking tl::tl_gemm with the composed string and the A/B/C buffer handles.
 *
 * @param T Contains lowering context including thread bounds and target.
 * @param analyzer Arithmetic analyzer used by lowering (dereferenced on the
 * TCGEN5MMA path, so it must not be nullptr there).
 * @return Stmt A TIR statement representing the evaluated TL GEMM call.
 */
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] =
      policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);

  std::stringstream ss;
  std::string op_name;

  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
    ICHECK(can_use_tcgen5mma);
    ICHECK(B.scope() == "shared.dyn" || B.scope() == "shared");
    ICHECK(C.scope() == "shared.tmem");
    ICHECK(mbar.has_value()) << "mbar must be provided for TCGEN5MMA";
    if (A.scope() == "shared.tmem") {
      op_name = "tl::tcgen5mma_gemm_ts";
    } else if (A.scope() == "shared.dyn" || A.scope() == "shared") {
      op_name = "tl::tcgen5mma_gemm_ss";
    } else {
      ICHECK(0)
          << "Unsupported A scope for TCGEN5MMA: "
          << A.scope(); // If this is triggered, it means Tilelang has bugs.
    }
    ICHECK(wg_wait == -1)
        << "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
           "use "
           "wg_wait = -1 and manually synchronize with mbarrier.";

    std::string accum_dtype = "";
    if (C->dtype.is_float()) {
      if (C->dtype.bits() == 32) {
        accum_dtype = "float";
      }
    }
    ICHECK(!accum_dtype.empty())
        << "Unsupported C dtype for TCGEN5MMA: " << C->dtype;
    ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
    ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", ";
    ss << trans_A << ", " << trans_B << ", ";
    ss << accum_dtype;
    ss << ">";

    auto C_buffer = T.buffer_remap.count(C) ? T.buffer_remap[C] : C;
    Array<PrimExpr> new_args;
    new_args.push_back(StringImm(ss.str()));
    new_args.push_back(Aptr);
    new_args.push_back(Bptr);
    new_args.push_back(BufferLoad(C_buffer, C_coords));
    new_args.push_back(mbarptr);
    new_args.push_back(clear_accum);
    auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);

    // Since TCGEN5MMA atoms provided by CUTLASS always have an internal
    // `elect_one_sync()`, we check if we are calling it using full warps
    constexpr int warp_size = 32;
    ICHECK(
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, warp_size), 0) &&
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->extent, warp_size),
                                0))
        << "TCGEN5MMA requires thread bounds to be multiples of warp size (32) "
           "and aligned to warps.";
    if (analyzer->CanProveEqual(T.thread_bounds->extent, warp_size)) {
      // If the thread bounds is exactly one warp, we can use the original call
      return Evaluate(new_call);
    } else {
      // Add an if-else clause
      auto tcgen5mma_call =
          IfThenElse(EQ(FloorDiv(T.thread_var, warp_size),
                        FloorDiv(T.thread_bounds->min, warp_size)),
                     Evaluate(new_call));
      return tcgen5mma_call;
    }
  }

  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  } else {
    op_name = "tl::gemm_ss";
  }
  ICHECK(C.scope() == "local.fragment");

  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  auto clear_accum_bool = clear_accum.as<Bool>();
  ICHECK(clear_accum_bool.has_value())
      << "clear_accum must be a constant Bool type, got " << clear_accum;
  ss << ", " << bool(clear_accum_bool.value());
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << stride_A << ", " << stride_B;
    ss << ", " << offset_A << ", " << offset_B;
  }
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }

  // Emit wg_wait if necessary
  if (TargetIsHopper(T.target)) {
    if (wg_wait != 0) {
      ss << ", " << wg_wait;
    }
  } else if (TargetIsSm100(T.target)) {
    // NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
    // but all threads need to wait, so we emit another statement for cases
    // where wg_wait == 0.
    ICHECK(wg_wait == 0 || wg_wait == -1)
        << "wg_wait must be 0 or -1 for Sm100";
  } else {
    ICHECK(wg_wait == 0)
        << "wg_wait must be 0 for non-Hopper and non-Sm100 targets";
  }
  ss << ">";

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}
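
// Illustrative emitted call (hedged; exact template arguments depend on the
// target, policy, and operand scopes): a 128x128x64 fp16 GEMM with A and B in
// shared memory and a 2x2 warp partition on an MMA-capable CUDA target yields
// a string resembling "tl::gemm_ss<128, 128, 64, 2, 2, 0, 0, 0, ...>", which
// is passed to tl::tl_gemm together with the A/B/C access pointers.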

/**
 * @brief Infer and bind target-specific memory/layout mappings for A, B, and C.
 *
 * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
 * operator according to the target architecture, thread bounds, warp
 * partitioning, data types, and transpose flags, then binds fragment layouts
 * to the thread range when required.
 *
 * Preconditions:
 * - C.scope() == "local.fragment" (or "shared.tmem" when TCGEN5MMA is used)
 *
 * Side effects:
 * - Marks layout inference as completed (sets completed_ = true).
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
 * @param T Input layout-inference context (provides thread bounds and target).
 * @return LayoutMap mapping A, B, and C to their inferred layouts.
 */
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] =
      policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
  if (TargetIsVolta(T.target)) {
    ICHECK(C.scope() == "local.fragment")
        << "Volta gemm only supports C in local.fragment scope, got "
        << C.scope();
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, !trans_A));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target) ||
             (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
    ICHECK(C.scope() == "local.fragment")
        << "MMA only supports C in local.fragment scope, got " << C.scope();

    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), !trans_A));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    ICHECK(C.scope() == "local.fragment")
        << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ")
        << "only supports C in local.fragment scope, got " << C.scope();
    auto fragment =
        gemm_inst == GemmInst::kWGMMA
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), !trans_A)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), !trans_A);
      results.Set(A, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;

      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B);
      results.Set(B, ABLayout);
    } else {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    }
  } else if (gemm_inst == GemmInst::kTCGEN5MMA) {
    ICHECK(C.scope() == "shared.tmem")
        << "TCGEN5MMA only supports C in shared.tmem scope, got " << C.scope();
    ICHECK(A.scope() == "shared.dyn" || A.scope() == "shared")
        << "Current TCGEN5MMA only supports A in shared.dyn scope";
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
    ICHECK(can_use_tcgen5mma);
    {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
                                           mat_continuous, A->dtype.bits(),
                                           trans_A ? 1 : 2));
    }
    {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity = mat_continuous;
      results.Set(B,
                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
                                        B->dtype.bits(), trans_B ? 2 : 1));
    }
    {
      Layout res;
      IterVar i = make_itervar("i", M);
      IterVar j = make_itervar("j", N);
      ICHECK(M % meta.atom_m == 0);
      PrimExpr atom_idx = FloorDiv(i, meta.atom_m) +
                          FloorDiv(j, meta.atom_n) * (M / meta.atom_m);
      PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i"
      PrimExpr aj = FloorMod(j, meta.atom_n);
      if (meta.atom_m == 128) {
        // Layout D
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d)
        res = Layout(Array{i, j}, {ai, aj + atom_idx * meta.atom_n});
      } else if (meta.atom_m == 64) {
        // Layout E
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e)
        // since .ws variant is used About why we use .ws variant here, please
        // refer to gemm_sm100.h
        res = Layout(Array{i, j}, {FloorDiv(ai, 32) * 32 + FloorMod(ai, 32) +
                                       FloorDiv(aj, meta.atom_n / 2) * 64,
                                   FloorMod(aj, meta.atom_n / 2) +
                                       atom_idx * (meta.atom_n / 2)});
      } else if (meta.atom_m == 32) {
        // Layout G
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-g)
        res = Layout(
            Array{i, j},
            {FloorMod(ai, 32) + FloorDiv(aj, meta.atom_n / 4) * 32,
             FloorMod(aj, meta.atom_n / 4) + atom_idx * (meta.atom_n / 4)});
      } else {
        ICHECK(0);
      }
      results.Set(C, res);
    }
  } else if (TargetIsCDNA(T.target)) {
    ICHECK(C.scope() == "local.fragment")
        << "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
        << C.scope();
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), kPack, trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);

      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}
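
// Note (hedged): in the TCGEN5MMA branch above, the C layout maps a logical
// (i, j) element onto tmem following the PTX tcgen05 data-path layouts:
// atom_m == 128 uses Layout D, atom_m == 64 uses Layout E (the .ws variant),
// and atom_m == 32 uses Layout G, with atoms tiled first along M and then
// along N (atom_idx = i / atom_m + (j / atom_n) * (M / atom_m)).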

TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tl.GemmWarpPolicy")
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");

TVM_FFI_STATIC_INIT_BLOCK({
  GemmNode::RegisterReflection();
  GemmWarpPolicyNode::RegisterReflection();
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
                        [](GemmWarpPolicy policy, int M, int N, int block_size,
                           Target target, GemmInst gemm_inst) {
                          policy->ComputeWarpPartition(M, N, block_size, target,
                                                       gemm_inst);
                          return;
                        });
});

} // namespace tl
} // namespace tvm