"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "5b8c054e9d8fffa864a825f9147cedc911b721fd"
gemm.cc 33.7 KB
Newer Older
1
2
/*!
 * \file tl/op/gemm.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm.h"

8
#include "builtin.h"
9
10
11
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
12
#include <tvm/tir/transform.h>
13
14
15
16
17
18
19
20

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
/// Instruction-atom shape selected for a TCGEN5MMA GEMM.
/// Field order matters: the struct is brace-initialised as {m, n, k}.
struct TCGEN5MMAMeta {
  int atom_m; ///< Atom extent along M
  int atom_n; ///< Atom extent along N
  int atom_k; ///< Atom extent along K
};

// Return {is_success, meta}
static inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL                                                                   \
  return {                                                                     \
    false, TCGEN5MMAMeta { 0, 0, 0 }                                           \
  }
#define SUCCESS(atom_m, atom_n, atom_k)                                        \
  return {                                                                     \
    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k }                             \
  }
  std::vector<int> ws_valid_atom_ns = {256, 128, 64};
  if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
      (c_dtype.is_float() && c_dtype.bits() == 32)) {
    if (K % 16 != 0)
      FAIL;
    if (M % 128 == 0) {
      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 16);
      FAIL;
    } else if (M % 64 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(64, atom_n, 16);
      FAIL;
    } else if (M % 32 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(32, atom_n, 16);
      FAIL;
    } else {
      FAIL;
    }
  } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
             (c_dtype.is_float() && c_dtype.bits() == 32)) {
    if (K % 32 != 0)
      FAIL;
    if (M % 128 == 0) {
      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 32);
      FAIL;
    } else if (M % 64 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(64, atom_n, 32);
      FAIL;
    } else if (M % 32 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(32, atom_n, 32);
      FAIL;
    } else {
      FAIL;
    }
  }
  FAIL;
#undef FAIL
#undef SUCCESS
}

88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/**
 * @brief Construct a Gemm operator from serialized TL arguments and a buffer
 * map.
 *
 * This constructor deserializes operator parameters from `args` and resolves
 * buffer references via `vmap`, populating an internal GemmNode with:
 * - device pointers for A, B, C and their corresponding Buffer objects,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
 *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      (optional) kPack (Int), (optional) wg_wait (Int)]
 * @param vmap Mapping from access pointer vars to Buffer objects used to
 *   resolve the Buffer corresponding to each pointer argument.
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
116
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
117
118
119
120
121
122
123
124
125
126
127
128
129
  ObjectPtr<GemmNode> node = make_object<GemmNode>();

  node->Aptr = args[0];
  node->Bptr = args[1];
  node->Cptr = args[2];
  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
  node->trans_A = args[3].as<Bool>().value();
  node->trans_B = args[4].as<Bool>().value();
  node->M = args[5].as<IntImm>().value()->value;
  node->N = args[6].as<IntImm>().value()->value;
  node->K = args[7].as<IntImm>().value()->value;
130
  node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
131
  node->clear_accum = args[9].as<PrimExpr>().value();
132
133
134
135
  node->stride_A = args[10].as<IntImm>().value()->value;
  node->stride_B = args[11].as<IntImm>().value()->value;
  node->offset_A = args[12].as<IntImm>().value()->value;
  node->offset_B = args[13].as<IntImm>().value()->value;
136
  if (args.size() > 14) {
137
138
    node->kPack = args[14].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
139
140
141
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
142
  if (args.size() > 15) {
143
    node->wg_wait = args[15].as<IntImm>().value()->value;
144
  }
145
146
147
148
149
150
151
152
  node->mbarptr = args[16];
  if (node->mbarptr.as<CallNode>()) {
    node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
  } else {
    node->mbar = std::nullopt;
  }
  node->C_coords = Array<PrimExpr>(
      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
153
  data_ = std::move(node);
154
155
}

156
157
158
159
160
161
162
163
/**
 * @brief Create a copy of this GemmNode as a TileOperator.
 *
 * Constructs a new GemmNode by copying the current node state and returns it
 * wrapped in a Gemm TileOperator.
 *
 * @return TileOperator A Gemm operator that owns a copy of this node.
 */
164
165
166
167
168
TileOperator GemmNode::Clone() const {
  auto op = make_object<GemmNode>(*this);
  return Gemm(op);
}

169
170
171
172
173
174
175
176
177
178
/**
 * @brief Decide whether this GEMM may use the TCGEN5MMA instruction.
 *
 * All of the following must hold:
 * - the target is Sm100;
 * - A is in "shared.dyn", "shared" or "shared.tmem";
 * - B is in "shared.dyn" or "shared";
 * - C is in "shared.tmem";
 * - GetTCGEN5MMAMeta finds an atom shape for (M, N, K, A dtype, C dtype).
 */
bool GemmNode::AllowTCGEN5MMA(Target target) const {
  if (!TargetIsSm100(target)) {
    return false;
  }

  const auto in_shared = [](const auto &scope) {
    return scope == "shared.dyn" || scope == "shared";
  };

  const bool a_ok = in_shared(A.scope()) || A.scope() == "shared.tmem";
  if (!a_ok || !in_shared(B.scope()) || !(C.scope() == "shared.tmem")) {
    return false;
  }

  return GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
}

bool GemmNode::AllowWGMMA(int block_size, Target target) const {
179
180
  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

181
182
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
183
184
185
186
187
188
189
190
191
192
193
  return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
         TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
         CheckWGMMA();
}

/**
 * @brief Choose the matrix-multiply instruction family for this GEMM.
 *
 * Preference order:
 * 1. kTCGEN5MMA when AllowTCGEN5MMA(target) holds;
 * 2. kWGMMA when AllowWGMMA(block_size, target) holds;
 * 3. kMFMA on CDNA targets;
 * 4. kMMA on Volta, Turing, Ampere, Hopper, Sm100 and SM120 targets.
 *
 * SM120 is accepted here because InferLayout has an MMA layout branch for
 * TargetIsSM120; without it that branch could never be reached.
 *
 * @param block_size Number of threads in the block.
 * @param target     Compilation target.
 * @return The selected GemmInst; aborts via ICHECK on unsupported targets.
 */
GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
  if (AllowTCGEN5MMA(target)) {
    return GemmInst::kTCGEN5MMA;
  }
  if (AllowWGMMA(block_size, target)) {
    return GemmInst::kWGMMA;
  }
  if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  }

  const bool supports_mma =
      TargetIsVolta(target) || TargetIsAmpere(target) ||
      TargetIsTuring(target) || TargetIsHopper(target) ||
      TargetIsSm100(target) || TargetIsSM120(target);
  ICHECK(supports_mma) << "Unsupported target for gemm: " << target->str();
  // Every control path now ends in an explicit return.
  return GemmInst::kMMA;
}

206
207
std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
    int M, int N, int block_size, Target target, GemmInst gemm_inst) const {
208
  int num_warps = block_size / TargetGetWarpSize(target);
209
210
211
212
  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning
  }

213
  int m_warp = 1, n_warp = 1;
214
215
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp
216
217
218
219
220
  ICHECK(M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << M;
  ICHECK(N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << N;

221
  if (gemm_inst == GemmInst::kWGMMA) {
222
223
224
225
226
227
228
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

229
    if (this->isFullRow()) {
230
231
232
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
233
        if (M % (cand * kMPerWarp) == 0) {
234
235
236
237
238
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
239
    } else if (this->isFullCol()) {
240
241
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
242
243
244
      int cand_n = n_warp;                 // Initially assume all on N
      if (N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = N / kNPerWarp;
245
246
247
248
249
250
251
252
253
254
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
255
    } else if (this->isSquare()) {
256
      // Exhaustive search, but m must be multiple of 4
257
258
      int max_m = M / kMPerWarp;
      int max_n = N / kNPerWarp;
259

260
      float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
261
262
263
264
265
266
267
268
269
270
271

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

272
273
        float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
274
275
276
277
278
279
280
281
282
283
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
284
285
286
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }
287
288
289

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";
290
291
292
293
294

    // Store the computed values in the object's member variables
    this->m_warp = m_warp;
    this->n_warp = n_warp;

295
296
    return {m_warp, n_warp};
  }
297

298
  if (this->isFullRow()) {
299
    // Try to partition M first
300
    m_warp = num_warps;
301
302
303
304
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
305
    if (M % (m_warp * kMPerWarp) != 0) {
306
      // Calculate how many warps we can use for M
307
      int max_m_warps = M / kMPerWarp;
308
309
310
311
312
313
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
314
  } else if (this->isFullCol()) {
315
316
    // Try to partition N first
    m_warp = 1;
317
    n_warp = num_warps;
318
319
320

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
321
    if (N % (n_warp * kNPerWarp) != 0) {
322
      // Calculate how many warps we can use for N
323
      int max_n_warps = N / kNPerWarp;
324
325
326
327
328
329
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
330
  } else if (this->isSquare()) {
331
    // First calculate the maximum possible warps for each dimension
332
    int max_m_warps =
333
        M / kMPerWarp; // Each warp needs at least 16 elements in M
334
335
336

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
337
338
    if (N > 0) {
      ideal_ratio = static_cast<float>(M) / N;
339
340
341
342
343
344
345
346
347
348
349
    }

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();
    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
350
351
      float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
352
353
354
355
356
357
358
      // m_per_warp and n_per_warp must be greater than 1
      if (m_per_warp < 1 || n_per_warp < 1)
        continue;
      // m * n must equal num_warps
      if (m * n != num_warps)
        continue;

359
360
361
362
363
364
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
365
366
      }
    }
367
368
369

    m_warp = best_m;
    n_warp = best_n;
370
371
372
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
373
374
375
376
  // Store the computed values in the object's member variables
  this->m_warp = m_warp;
  this->n_warp = n_warp;

377
378
379
  return {m_warp, n_warp};
}

380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
396
397
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K %
 * 32 == 0
398
399
400
401
402
403
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
404
405
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 * and K % 32 == 0
406
407
408
409
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
410
bool GemmNode::CheckWGMMA() const {
411
412
413
414
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }

415
416
417
  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
418
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
419
      return (!trans_A) && trans_B && K % 32 == 0;
420
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
421
      return (!trans_A) && trans_B && K % 32 == 0;
422
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
423
      return (!trans_A) && trans_B && K % 32 == 0;
424
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
425
426
427
428
429
430
431
432
433
434
435
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
436
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
437
      return (!trans_A) && trans_B && K % 32 == 0;
438
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
439
      return (!trans_A) && trans_B && K % 32 == 0;
440
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
441
      return (!trans_A) && trans_B && K % 32 == 0;
442
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}

462
463
464
465
466
467
468
469
470
471
472
473
474
475
/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it matches the pattern
 * "sm_<num>", returns <num> as an int. If the attribute is present but does not
 * match that pattern, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked via
 * ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
476
477
478
479
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
480
481
482
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
483
484
485
486
487
488
  } else {
    arch_int = 0;
  }
  return arch_int;
}

489
490
491
492
493
494
495
496
497
498
499
500
501
/**
 * @brief Lower the GEMM operator to a TL TIR call expression.
 *
 * Constructs a tl::gemm call string parameterized by M, N, K, warp partition,
 * transpose flags, accumulation clearing, target-specific stride/offset/kPack
 * and optional workgroup wait value, then returns an Evaluate(call) node
 * invoking tl::tl_gemm with the composed string and the A/B/C buffer handles.
 *
 * @param T Contains lowering context including thread bounds and target.
 * @param analyzer Optional arithmetic analyzer used by lowering (may be
 * nullptr).
 * @return Stmt A TIR statement representing the evaluated TL GEMM call.
 */
502
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
503
  auto block_size = *as_const_int(T.thread_bounds->extent);
504
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
505
506
  auto [warp_m, warp_n] =
      policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
507

508
  std::stringstream ss;
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
  std::string op_name;

  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
    ICHECK(can_use_tcgen5mma);
    ICHECK(B.scope() == "shared.dyn" || B.scope() == "shared");
    ICHECK(C.scope() == "shared.tmem");
    ICHECK(mbar.has_value()) << "mbar must be provided for TCGEN5MMA";
    if (A.scope() == "shared.tmem") {
      op_name = "tl::tcgen5mma_gemm_ts";
    } else if (A.scope() == "shared.dyn" || A.scope() == "shared") {
      op_name = "tl::tcgen5mma_gemm_ss";
    } else {
      ICHECK(0)
          << "Unsupported A scope for TCGEN5MMA: "
          << A.scope(); // If this is triggered, it means Tilelang has bugs.
    }
    ICHECK(wg_wait == -1)
        << "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
           "use "
           "wg_wait = -1 and manually synchronize with mbarrier.";

    std::string accum_dtype = "";
    if (C->dtype.is_float()) {
      if (C->dtype.bits() == 32) {
        accum_dtype = "float";
      }
    }
    ICHECK(!accum_dtype.empty())
        << "Unsupported C dtype for TCGEN5MMA: " << C->dtype;
    ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
    ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", ";
    ss << trans_A << ", " << trans_B << ", ";
    ss << accum_dtype;
    ss << ">";

    auto C_buffer = T.buffer_remap.count(C) ? T.buffer_remap[C] : C;
    Array<PrimExpr> new_args;
    new_args.push_back(StringImm(ss.str()));
    new_args.push_back(Aptr);
    new_args.push_back(Bptr);
    new_args.push_back(BufferLoad(C_buffer, C_coords));
    new_args.push_back(mbarptr);
    new_args.push_back(clear_accum);
    auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);

    // Since TCGEN5MMA atoms provided by CUTLASS always have an internal
    // `elect_one_sync()`, we check if we are calling it using full warps
    constexpr int warp_size = 32;
    ICHECK(
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, warp_size), 0) &&
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->extent, warp_size),
                                0))
        << "TCGEN5MMA requires thread bounds to be multiples of warp size (32) "
           "and aligned to warps.";
    if (analyzer->CanProveEqual(T.thread_bounds->extent, warp_size)) {
      // If the thread bounds is exactly one warp, we can use the original call
      return Evaluate(new_call);
    } else {
      // Add an if-else clause
      auto tcgen5mma_call =
          IfThenElse(EQ(FloorDiv(T.thread_var, warp_size),
                        FloorDiv(T.thread_bounds->min, warp_size)),
                     Evaluate(new_call));
      return tcgen5mma_call;
    }
  }

578
579
580
581
582
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
583
584
  } else {
    op_name = "tl::gemm_ss";
585
  }
586
587
  ICHECK(C.scope() == "local.fragment");

588
589
590
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
591
592
593
594
  auto clear_accum_bool = clear_accum.as<Bool>();
  ICHECK(clear_accum_bool.has_value())
      << "clear_accum must be a constant Bool type, got " << clear_accum;
  ss << ", " << bool(clear_accum_bool.value());
595
596
597
598
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << stride_A << ", " << stride_B;
    ss << ", " << offset_A << ", " << offset_B;
  }
599
600
601
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
602
  } else if (TargetIsHopper(T.target)) {
603
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
604
  }
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619

  // Emit wg_wait if necessary
  if (TargetIsHopper(T.target)) {
    if (wg_wait != 0) {
      ss << ", " << wg_wait;
    }
  } else if (TargetIsSm100(T.target)) {
    // NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
    // but all threads need to wait, so we emit another statement for cases
    // where wg_wait == 0.
    ICHECK(wg_wait == 0 || wg_wait == -1)
        << "wg_wait must be 0 or -1 for Sm100";
  } else {
    ICHECK(wg_wait == 0)
        << "wg_wait must be 0 for non-Hopper and non-Sm100 targets";
620
  }
621
  ss << ">";
622
623
624

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
625
626
627
  return Evaluate(new_call);
}

628
/**
629
 * @brief Infer and bind target-specific memory/layout mappings for A, B, and C.
630
 *
631
632
633
634
 * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
 * operator according to the target architecture, thread bounds, warp
 * partitioning, data types, and transpose flags, then binds fragment layouts
 * to the thread range when required.
635
636
 *
 * Preconditions:
637
 * - C.scope() == "local.fragment"
638
 *
639
640
 * Side effects:
 * - Marks layout inference as completed (sets completed_ = true).
641
642
643
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
644
645
 * @param T Input layout-inference context (provides thread bounds and target).
 * @return LayoutMap mapping A, B, and C to their inferred layouts.
646
 */
647
648
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
649
650
  if (completed_)
    return {};
651
  LayoutMap results;
652
653
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
654
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
655
656
  auto [warp_m, warp_n] =
      policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
657
  if (TargetIsVolta(T.target)) {
658
659
660
    ICHECK(C.scope() == "local.fragment")
        << "Volta gemm only supports C in local.fragment scope, got "
        << C.scope();
661
662
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
663
    results.Set(C, fragment->BindThreadRange(thread_range));
664
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
665
666
667
668
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
669
670
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
671
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
672
      results.Set(A, fragment->BindThreadRange(thread_range));
673
674
675
676
677
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
678
679
680
681
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
682
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
683
684
685
686
687
             TargetIsSM120(T.target) ||
             (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
    ICHECK(C.scope() == "local.fragment")
        << "MMA only supports C in local.fragment scope, got " << C.scope();

688
689
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
690
    results.Set(C, fragment->BindThreadRange(thread_range));
691
692

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
693
694
695
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
696
697
698
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
699
    } else if (A.scope() == "local.fragment") {
700
701
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
702
      results.Set(A, fragment->BindThreadRange(thread_range));
703
704
705
706
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
707
708
709
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
710
711
712
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
713
    } else if (B.scope() == "local.fragment") {
714
715
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
716
      results.Set(B, fragment->BindThreadRange(thread_range));
717
718
719
720
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
721
722
723
    ICHECK(C.scope() == "local.fragment")
        << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ")
        << "only supports C in local.fragment scope, got " << C.scope();
724
    auto fragment =
725
        gemm_inst == GemmInst::kWGMMA
726
727
728
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
729
    results.Set(C, fragment->BindThreadRange(thread_range));
730
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
731
732
733
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
734
      const int64_t continuity =
735
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
736
      auto ABLayout =
737
          gemm_inst == GemmInst::kWGMMA
738
739
740
741
742
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), trans_A ? 1 : 2)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), trans_A ? 1 : 2);
      results.Set(A, ABLayout);
743
    } else {
744
745
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
746
      results.Set(A, fragment->BindThreadRange(thread_range));
747
748
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
749
750
751
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
752
753
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
754
      auto ABLayout =
755
          gemm_inst == GemmInst::kWGMMA
756
757
758
759
760
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B ? 2 : 1)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B ? 2 : 1);
      results.Set(B, ABLayout);
761
    } else {
762
763
764
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
765
    }
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
  } else if (gemm_inst == GemmInst::kTCGEN5MMA) {
    ICHECK(C.scope() == "shared.tmem")
        << "TCGEN5MMA only supports C in shared.tmem scope, got " << C.scope();
    ICHECK(A.scope() == "shared.dyn" || A.scope() == "shared")
        << "Current TCGEN5MMA only supports A in shared.dyn scope";
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
    ICHECK(can_use_tcgen5mma);
    {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
                                           mat_continuous, A->dtype.bits(),
                                           trans_A ? 1 : 2));
    }
    {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity = mat_continuous;
      results.Set(B,
                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
                                        B->dtype.bits(), trans_B ? 2 : 1));
    }
    {
      Layout res;
      IterVar i = make_itervar("i", M);
      IterVar j = make_itervar("j", N);
      ICHECK(M % meta.atom_m == 0);
      PrimExpr atom_idx = FloorDiv(i, meta.atom_m) +
                          FloorDiv(j, meta.atom_n) * (M / meta.atom_m);
      PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i"
      PrimExpr aj = FloorMod(j, meta.atom_n);
      if (meta.atom_m == 128) {
        // Layout D
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d)
        res = Layout(Array{i, j}, {ai, aj + atom_idx * meta.atom_n});
      } else if (meta.atom_m == 64) {
        // Layout E
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e)
        // since the .ws variant is used. For why the .ws variant is used here,
        // please refer to gemm_sm100.h.
        res = Layout(Array{i, j}, {FloorDiv(ai, 32) * 32 + FloorMod(ai, 32) +
                                       FloorDiv(aj, meta.atom_n / 2) * 64,
                                   FloorMod(aj, meta.atom_n / 2) +
                                       atom_idx * (meta.atom_n / 2)});
      } else if (meta.atom_m == 32) {
        // Layout G
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-g)
        res = Layout(
            Array{i, j},
            {FloorMod(ai, 32) + FloorDiv(aj, meta.atom_n / 4) * 32,
             FloorMod(aj, meta.atom_n / 4) + atom_idx * (meta.atom_n / 4)});
      } else {
        ICHECK(0);
      }
      results.Set(C, res);
    }
825
  } else if (TargetIsCDNA(T.target)) {
826
827
828
    ICHECK(C.scope() == "local.fragment")
        << "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
        << C.scope();
829
830
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
831
    results.Set(C, fragment->BindThreadRange(thread_range));
832
833

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
834
835
836
837
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
838
839
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
840
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
841
                                            A->dtype.bits(), kPack, trans_A);
842
      results.Set(A, fragment->BindThreadRange(thread_range));
843
844
845
846
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
847
848
849
850
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
851
852

      results.Set(B, shared_layout);
853
854
855
856
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
857
858
859
860
861
862
863
864
865
866
867
868
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register the Gemm tile operator under the op name `gemm`.
//  - The op is declared with 5 inputs.
//  - Its call effect is set to CallEffectKind::kOpaque (presumably so that
//    TIR passes treat the call as having arbitrary side effects and neither
//    reorder nor eliminate it -- confirm against the TVM CallEffectKind docs).
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
871

872
873
874
875
876
877
878
879
880
// Register the "tl.GemmWarpPolicy" op and attach the "TScriptPrinterName"
// attribute with value "GemmWarpPolicy" (presumably the name the TVMScript
// printer emits for this op -- confirm against the TVM op-attribute docs).
TVM_REGISTER_OP("tl.GemmWarpPolicy")
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");

// Static-initialization block for this translation unit:
//  1. Registers reflection information for GemmNode and GemmWarpPolicyNode.
//  2. Defines the global function "tl.GemmWarpPolicyComputeWarpPartition",
//     which forwards its arguments to GemmWarpPolicy::ComputeWarpPartition.
//     The wrapper returns nothing; any value returned by
//     ComputeWarpPartition is discarded, so callers can only observe
//     whatever state that method stores on `policy` itself.
TVM_FFI_STATIC_INIT_BLOCK({
  GemmNode::RegisterReflection();
  GemmWarpPolicyNode::RegisterReflection();
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
                        [](GemmWarpPolicy policy, int M, int N, int block_size,
                           Target target, GemmInst gemm_inst) {
                          policy->ComputeWarpPartition(M, N, block_size, target,
                                                       gemm_inst);
                        });
});

888
} // namespace tl
889
} // namespace tvm