gemm.cc 32.3 KB
Newer Older
1
2
/*!
 * \file tl/op/gemm.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm.h"
7
#include "builtin.h"
8
#include <fstream>
9
10
11
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
12
#include <tvm/tir/transform.h>
13
14

#include "../target/utils.h"
15
#include "tcgen5_meta.h"
16
#include "utils.h"
17
18
19
20
21
22

namespace tvm {
namespace tl {

using namespace tir;

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/**
 * @brief Construct a Gemm operator from serialized TL arguments and a buffer
 * map.
 *
 * This constructor deserializes operator parameters from `args` and resolves
 * buffer references via `vmap`, populating an internal GemmNode with:
 * - device pointers for A, B, C and their corresponding Buffer objects,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
 *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      (optional) kPack (Int), (optional) wg_wait (Int)]
 *
45
 * @note If `kPack` is provided it must be 1; otherwise the constructor
46
47
48
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
49
// NormalizeToBufferRegion moved to src/op/utils.{h,cc}
50

51
// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
52

53
Gemm::Gemm(Array<PrimExpr> args) {
54
  ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();
55

56
57
58
  node->aRegion_ = NormalizeToBufferRegion(args[0]);
  node->bRegion_ = NormalizeToBufferRegion(args[1]);
  node->cRegion_ = NormalizeToBufferRegion(args[2]);
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

  node->a_ = node->aRegion_->buffer;
  node->b_ = node->bRegion_->buffer;
  node->c_ = node->cRegion_->buffer;
  node->transA_ = args[3].as<Bool>().value();
  node->transB_ = args[4].as<Bool>().value();
  node->m_ = args[5].as<IntImm>().value()->value;
  node->n_ = args[6].as<IntImm>().value()->value;
  node->k_ = args[7].as<IntImm>().value()->value;
  node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clearAccum_ = args[9].as<PrimExpr>().value();
  node->strideA_ = args[10].as<IntImm>().value()->value;
  node->strideB_ = args[11].as<IntImm>().value()->value;
  node->offsetA_ = args[12].as<IntImm>().value()->value;
  node->offsetB_ = args[13].as<IntImm>().value()->value;
74
  if (args.size() > 14) {
75
76
    node->kPack_ = args[14].as<IntImm>().value()->value;
    if (node->kPack_ != 1 && node->kPack_ != 2) {
77
78
79
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
80
  if (args.size() > 15) {
81
    node->wgWait_ = args[15].as<IntImm>().value()->value;
82
  }
83
84
85
86
87
88
89
90
  if (args.size() > 16) {
    if (const auto *load = args[16].as<BufferLoadNode>()) {
      node->mbarRegion_ =
          NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
      node->mbar_ = node->mbarRegion_->buffer;
    } else {
      node->mbar_ = std::nullopt;
    }
91
  }
92
  node->cCoords_ = Array<PrimExpr>(
93
      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
94
  data_ = std::move(node);
95
96
}

97
98
99
100
101
102
103
104
/**
 * @brief Create a copy of this GemmNode as a TileOperator.
 *
 * Constructs a new GemmNode by copying the current node state and returns it
 * wrapped in a Gemm TileOperator.
 *
 * @return TileOperator A Gemm operator that owns a copy of this node.
 */
105
TileOperator GemmNode::Clone() const {
106
  auto op = tvm::ffi::make_object<GemmNode>(*this);
107
108
109
  return Gemm(op);
}

110
/**
 * @brief Whether the Sm100 TCGEN5MMA path may be used for this GEMM.
 *
 * Requires an Sm100 target; A in "shared.dyn", "shared" or "shared.tmem";
 * B in "shared.dyn" or "shared"; C in "shared.tmem". GetTCGEN5MMAMeta must
 * also report the (M, N, K, A dtype, C dtype) combination as supported.
 */
bool GemmNode::allowTcgen5Mma(Target target) const {
  if (!TargetIsSm100(target)) {
    return false;
  }

  const bool a_scope_ok = a_.scope() == "shared.dyn" ||
                          a_.scope() == "shared" ||
                          a_.scope() == "shared.tmem";
  if (!a_scope_ok) {
    return false;
  }
  const bool b_scope_ok = b_.scope() == "shared.dyn" || b_.scope() == "shared";
  if (!b_scope_ok || c_.scope() != "shared.tmem") {
    return false;
  }

  // Only query the meta table once every scope requirement holds.
  auto meta_result = GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
  return meta_result.first;
}

119
bool GemmNode::allowWgmma(int block_size, Target target) const {
120
121
  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

122
123
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
124
  return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
125
126
         TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
         checkWgmma();
127
128
}

129
130
/**
 * @brief Select the matrix-multiply instruction family used for this GEMM.
 *
 * Candidates are tried in priority order:
 * TCGEN5MMA, WGMMA, DCU (MMAC), CDNA (MFMA), then generic CUDA (MMA).
 * Any other target fails with an ICHECK.
 */
GemmInst GemmNode::getGemmInst(int block_size, Target target) const {
  if (allowTcgen5Mma(target))
    return GemmInst::kTCGEN5MMA;
  if (allowWgmma(block_size, target))
    return GemmInst::kWGMMA;
  // Keep the DCU test ahead of the CDNA test, so DCU targets that may also
  // satisfy TargetIsCDNA are not classified as MFMA.
  if (TargetIsDCU(target))
    return GemmInst::KMMAC;
  if (TargetIsCDNA(target))
    return GemmInst::kMFMA;
  if (TargetIsCuda(target))
    return GemmInst::kMMA;

  ICHECK(0) << "Unsupported target for gemm: " << target;
  return GemmInst::kMMA; // not reached after the failed ICHECK
}

146
std::pair<int, int> GemmWarpPolicyNode::computeWarpPartition(
147
    int M, int N, int block_size, Target target, GemmInst gemm_inst) const {
148
  int num_warps = block_size / TargetGetWarpSize(target);
149
150
151
152
  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning
  }

153
  int m_warp = 1, n_warp = 1;
154
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
155
156
157
158
  int kNPerWarp = 8;            // Columns processed by a single warp
  if (TargetIsVolta(target)) {
    kNPerWarp = 16;
  }
159
160
161
162
163
  ICHECK(M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << M;
  ICHECK(N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << N;

164
  if (gemm_inst == GemmInst::kWGMMA) {
165
166
167
168
169
170
171
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

172
    if (this->isFullRow()) {
173
174
175
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
176
        if (M % (cand * kMPerWarp) == 0) {
177
178
179
180
181
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
182
    } else if (this->isFullCol()) {
183
184
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
185
186
187
      int cand_n = n_warp;                 // Initially assume all on N
      if (N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = N / kNPerWarp;
188
189
190
191
192
193
194
195
196
197
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
198
    } else if (this->isSquare()) {
199
      // Exhaustive search, but m must be multiple of 4
200
201
      int max_m = M / kMPerWarp;
      int max_n = N / kNPerWarp;
202

203
      float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
204
205
206
207
208
209
210
211
212
213
214

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

215
216
        float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
217
218
219
220
221
222
223
224
225
226
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
227
228
229
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }
230
231

    ICHECK(m_warp * n_warp == num_warps)
232
233
        << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
        << ", n_warp: " << n_warp << ", num_warps: " << num_warps;
234
235
236
237
238

    // Store the computed values in the object's member variables
    this->m_warp = m_warp;
    this->n_warp = n_warp;

239
240
    return {m_warp, n_warp};
  }
241

242
  if (this->isFullRow()) {
243
    // Try to partition M first
244
    m_warp = num_warps;
245
246
247
248
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
249
    if (M % (m_warp * kMPerWarp) != 0) {
250
      // Calculate how many warps we can use for M
251
      int max_m_warps = M / kMPerWarp;
252
253
254
255
256
257
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
258
  } else if (this->isFullCol()) {
259
260
    // Try to partition N first
    m_warp = 1;
261
    n_warp = num_warps;
262
263
264

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
265
    if (N % (n_warp * kNPerWarp) != 0) {
266
      // Calculate how many warps we can use for N
267
      int max_n_warps = N / kNPerWarp;
268
269
270
271
272
273
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
274
  } else if (this->isSquare()) {
275
    // First calculate the maximum possible warps for each dimension
276
    int max_m_warps =
277
        M / kMPerWarp; // Each warp needs at least 16 elements in M
278
279
280

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
281
282
    if (N > 0) {
      ideal_ratio = static_cast<float>(M) / N;
283
284
285
286
287
288
289
290
291
292
293
    }

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();
    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
294
295
      float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
296
297
298
299
300
301
302
      // m_per_warp and n_per_warp must be greater than 1
      if (m_per_warp < 1 || n_per_warp < 1)
        continue;
      // m * n must equal num_warps
      if (m * n != num_warps)
        continue;

303
304
305
306
307
308
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
309
310
      }
    }
311
312
313

    m_warp = best_m;
    n_warp = best_n;
314
315
316
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
317
318
319
320
  ICHECK(m_warp * n_warp == num_warps)
      << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
      << ", n_warp: " << n_warp << ", num_warps: " << num_warps;

321
322
323
324
  // Store the computed values in the object's member variables
  this->m_warp = m_warp;
  this->n_warp = n_warp;

325
326
327
  return {m_warp, n_warp};
}

328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
344
345
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K %
 * 32 == 0
346
347
348
349
350
351
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
352
353
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 * and K % 32 == 0
354
355
356
357
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
358
359
bool GemmNode::checkWgmma() const {
  if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
360
361
362
    return false;
  }

363
364
365
  if (c_->dtype == DataType::Float(16)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
366
    else if (a_->dtype.is_float8() && b_->dtype.is_float8())
367
      return (!transA_) && transB_ && k_ % 32 == 0;
368
369
    else
      return false;
370
371
372
373
374
375
376
377
378
  } else if (c_->dtype == DataType::Float(32)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::BFloat(16) &&
             b_->dtype == DataType::BFloat(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::Float(32) &&
             b_->dtype == DataType::Float(32))
      return (!transA_) && transB_ && k_ % 8 == 0;
379
    else if (a_->dtype.is_float8() && b_->dtype.is_float8())
380
      return (!transA_) && transB_ && k_ % 32 == 0;
381
382
    else
      return false;
383
384
385
386
387
388
389
390
391
  } else if (c_->dtype == DataType::Int(32)) {
    if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
392
393
394
395
396
397
398
    else
      return false;
  } else {
    return false;
  }
}

399
400
401
402
403
404
405
406
407
408
409
410
411
412
/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it matches the pattern
 * "sm_<num>", returns <num> as an int. If the attribute is present but does not
 * match that pattern, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked via
 * ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
413
414
static int GetArchInt(Target target) {
  int arch_int = 0;
415
416
  auto s = target->GetAttr<tvm::ffi::String>("arch");
  ICHECK(s.has_value());
417
418
419
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
420
421
422
423
424
425
  } else {
    arch_int = 0;
  }
  return arch_int;
}

426
427
428
429
430
431
432
433
434
435
436
437
438
/**
 * @brief Lower the GEMM operator to a TL TIR call expression.
 *
 * Constructs a tl::gemm call string parameterized by M, N, K, warp partition,
 * transpose flags, accumulation clearing, target-specific stride/offset/kPack
 * and optional workgroup wait value, then returns an Evaluate(call) node
 * invoking tl::tl_gemm with the composed string and the A/B/C buffer handles.
 *
 * @param T Contains lowering context including thread bounds and target.
 * @param analyzer Optional arithmetic analyzer used by lowering (may be
 * nullptr).
 * @return Stmt A TIR statement representing the evaluated TL GEMM call.
 */
439
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
440
  auto block_size = *as_const_int(T.thread_bounds->extent);
441
  GemmInst gemm_inst = getGemmInst(block_size, T.target);
442
  auto [warp_m, warp_n] =
443
444
445
      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);

  // Build access pointers from regions locally
446
447
448
449
450
451
  PrimExpr Aptr =
      MakeAccessPtrFromRegion(aRegion_, /*r*/ 1, /*require_2d*/ true);
  PrimExpr Bptr =
      MakeAccessPtrFromRegion(bRegion_, /*r*/ 1, /*require_2d*/ true);
  PrimExpr Cptr =
      MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3, /*require_2d*/ true);
452

453
  std::stringstream ss;
454
455
456
457
  std::string op_name;

  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    auto [can_use_tcgen5mma, meta] =
458
        GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
459
    ICHECK(can_use_tcgen5mma);
460
461
462
463
    ICHECK(b_.scope() == "shared.dyn" || b_.scope() == "shared");
    ICHECK(c_.scope() == "shared.tmem");
    ICHECK(mbar_.has_value()) << "mbar must be provided for TCGEN5MMA";
    if (a_.scope() == "shared.tmem") {
464
      op_name = "tl::tcgen5mma_gemm_ts";
465
    } else if (a_.scope() == "shared.dyn" || a_.scope() == "shared") {
466
467
468
469
      op_name = "tl::tcgen5mma_gemm_ss";
    } else {
      ICHECK(0)
          << "Unsupported A scope for TCGEN5MMA: "
470
          << a_.scope(); // If this is triggered, it means Tilelang has bugs.
471
    }
472
    ICHECK(wgWait_ == -1)
473
474
475
476
477
        << "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
           "use "
           "wg_wait = -1 and manually synchronize with mbarrier.";

    std::string accum_dtype = "";
478
479
    if (c_->dtype.is_float()) {
      if (c_->dtype.bits() == 32) {
480
481
482
483
        accum_dtype = "float";
      }
    }
    ICHECK(!accum_dtype.empty())
484
485
        << "Unsupported C dtype for TCGEN5MMA: " << c_->dtype;
    ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
486
    ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", ";
487
    ss << transA_ << ", " << transB_ << ", ";
488
489
490
    ss << accum_dtype;
    ss << ">";

491
    auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
492
    Array<PrimExpr> new_args;
493
494
    auto mbarPtr =
        MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true);
495
496
497
    new_args.push_back(StringImm(ss.str()));
    new_args.push_back(Aptr);
    new_args.push_back(Bptr);
498
    new_args.push_back(BufferLoad(C_buffer, cCoords_));
499
    new_args.push_back(mbarPtr);
500
    new_args.push_back(clearAccum_);
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
    auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);

    // Since TCGEN5MMA atoms provided by CUTLASS always have an internal
    // `elect_one_sync()`, we check if we are calling it using full warps
    constexpr int warp_size = 32;
    ICHECK(
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, warp_size), 0) &&
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->extent, warp_size),
                                0))
        << "TCGEN5MMA requires thread bounds to be multiples of warp size (32) "
           "and aligned to warps.";
    if (analyzer->CanProveEqual(T.thread_bounds->extent, warp_size)) {
      // If the thread bounds is exactly one warp, we can use the original call
      return Evaluate(new_call);
    } else {
      // Add an if-else clause
      auto tcgen5mma_call =
          IfThenElse(EQ(FloorDiv(T.thread_var, warp_size),
                        FloorDiv(T.thread_bounds->min, warp_size)),
                     Evaluate(new_call));
      return tcgen5mma_call;
    }
  }

525
526
527
  if (a_.scope() == "local.fragment") {
    ICHECK(b_.scope() != "local.fragment");
    ICHECK(!transA_)
528
        << "gemm_rs requires the A operand to be in non-transposed layout.";
529
    op_name = "tl::gemm_rs";
530
  } else if (b_.scope() == "local.fragment") {
531
    op_name = "tl::gemm_sr";
532
533
  } else {
    op_name = "tl::gemm_ss";
534
  }
535
  ICHECK(c_.scope() == "local.fragment");
536

537
  ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
538
  ss << warp_m << ", " << warp_n << ", ";
539
540
  ss << transA_ << ", " << transB_;
  auto clear_accum_bool = clearAccum_.as<Bool>();
541
  ICHECK(clear_accum_bool.has_value())
542
      << "clear_accum must be a constant Bool type, got " << clearAccum_;
543
  ss << ", " << bool(clear_accum_bool.value());
544
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
545
546
    ss << ", " << strideA_ << ", " << strideB_;
    ss << ", " << offsetA_ << ", " << offsetB_;
547
  }
548
549
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
550
    ss << ", " << kPack_;
551
  } else if (TargetIsHopper(T.target)) {
552
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
553
  }
554
555
556

  // Emit wg_wait if necessary
  if (TargetIsHopper(T.target)) {
557
558
    if (wgWait_ != 0) {
      ss << ", " << wgWait_;
559
560
561
562
563
    }
  } else if (TargetIsSm100(T.target)) {
    // NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
    // but all threads need to wait, so we emit another statement for cases
    // where wg_wait == 0.
564
    ICHECK(wgWait_ == 0 || wgWait_ == -1)
565
566
        << "wg_wait must be 0 or -1 for Sm100";
  } else {
567
    ICHECK(wgWait_ == 0)
568
        << "wg_wait must be 0 for non-Hopper and non-Sm100 targets";
569
  }
570
  ss << ">";
571
572
573

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
574
575
576
  return Evaluate(new_call);
}

577
/**
578
 * @brief Infer and bind target-specific memory/layout mappings for A, B, and C.
579
 *
580
581
582
583
 * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
 * operator according to the target architecture, thread bounds, warp
 * partitioning, data types, and transpose flags, then binds fragment layouts
 * to the thread range when required.
584
585
 *
 * Preconditions:
586
 * - C.scope() == "local.fragment"
587
 *
588
589
 * Side effects:
 * - Marks layout inference as completed (sets completed_ = true).
590
591
592
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
593
594
 * @param T Input layout-inference context (provides thread bounds and target).
 * @return LayoutMap mapping A, B, and C to their inferred layouts.
595
 */
596
597
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
598
599
  if (completed_)
    return {};
600
  LayoutMap results;
601
602
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
603
  GemmInst gemm_inst = getGemmInst(block_size, T.target);
604
  auto [warp_m, warp_n] =
605
      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
606
  if (TargetIsVolta(T.target)) {
607
    ICHECK(c_.scope() == "local.fragment")
608
        << "Volta gemm only supports C in local.fragment scope, got "
609
610
611
612
613
614
615
616
617
618
619
620
621
622
        << c_.scope();
    auto fragment = makeGemmVoltaFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
                                           c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));
    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      results.Set(a_, makeGemmVoltaABLayout(*as_const_int(a_->shape[dim_A - 2]),
                                            *as_const_int(a_->shape[dim_A - 1]),
                                            true, !transA_));
    } else if (a_.scope() == "local.fragment") {
      ICHECK(transA_ == false);
      auto fragment =
          makeGemmVoltaFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n);
      results.Set(a_, fragment->BindThreadRange(thread_range));
623
624
625
626
    } else {
      ICHECK(0);
    }

627
628
629
630
631
    ICHECK(b_.scope() == "shared" || b_.scope() == "shared.dyn");
    int dim_B = b_->shape.size();
    results.Set(b_, makeGemmVoltaABLayout(*as_const_int(b_->shape[dim_B - 2]),
                                          *as_const_int(b_->shape[dim_B - 1]),
                                          false, transB_));
632
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
633
634
             TargetIsSM120(T.target) ||
             (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
635
636
    ICHECK(c_.scope() == "local.fragment")
        << "MMA only supports C in local.fragment scope, got " << c_.scope();
637

638
    auto fragment =
639
640
641
642
643
644
645
646
        makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));

    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
      results.Set(a_,
647
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
648
649
650
651
652
                                   a_->dtype.bits(), !transA_));
    } else if (a_.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                        a_->dtype.bits(), transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
653
654
655
    } else {
      ICHECK(0);
    }
656
657
658
659
660
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
      results.Set(b_,
661
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
662
663
                                   b_->dtype.bits(), transB_));
    } else if (b_.scope() == "local.fragment") {
664
      auto fragment =
665
666
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
667
668
669
670
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
671
    ICHECK(c_.scope() == "local.fragment")
672
        << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ")
673
674
675
676
677
678
679
680
681
682
683
        << "only supports C in local.fragment scope, got " << c_.scope();
    auto fragment = gemm_inst == GemmInst::kWGMMA
                        ? makeGemmFragmentCHopper(m_, n_, m_ / warp_m,
                                                  n_ / warp_n, c_->dtype.bits())
                        : makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
                                            c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));
    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
684
      const int64_t continuity =
685
          transA_ ? 4 * mat_continuous / warp_m : mat_continuous;
686
      auto ABLayout =
687
          gemm_inst == GemmInst::kWGMMA
688
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
689
                                       a_->dtype.bits(), !transA_)
690
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
691
692
                                 a_->dtype.bits(), !transA_);
      results.Set(a_, ABLayout);
693
    } else {
694
695
696
      auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                        a_->dtype.bits(), transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
697
    }
698
699
700
701
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
702
      const int64_t continuity =
703
          transB_ ? mat_continuous : mat_continuous / warp_n;
704

705
      auto ABLayout =
706
          gemm_inst == GemmInst::kWGMMA
707
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
708
                                       b_->dtype.bits(), transB_)
709
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
710
711
                                 b_->dtype.bits(), transB_);
      results.Set(b_, ABLayout);
712
    } else {
713
      auto fragment =
714
715
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
716
    }
717
  } else if (gemm_inst == GemmInst::kTCGEN5MMA) {
718
719
720
    ICHECK(c_.scope() == "shared.tmem")
        << "TCGEN5MMA only supports C in shared.tmem scope, got " << c_.scope();
    ICHECK(a_.scope() == "shared.dyn" || a_.scope() == "shared")
721
722
        << "Current TCGEN5MMA only supports A in shared.dyn scope";
    auto [can_use_tcgen5mma, meta] =
723
        GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
724
725
    ICHECK(can_use_tcgen5mma);
    {
726
727
728
729
730
731
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
      results.Set(a_, makeGemmABLayoutSm100(mat_stride, mat_continuous,
                                            mat_continuous, a_->dtype.bits(),
                                            transA_ ? 1 : 2));
732
733
    }
    {
734
735
736
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
737
      const int64_t continuity = mat_continuous;
738
      results.Set(b_,
739
                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
740
                                        b_->dtype.bits(), transB_ ? 2 : 1));
741
742
743
    }
    {
      Layout res;
744
745
746
      IterVar i = make_itervar("i", m_);
      IterVar j = make_itervar("j", n_);
      ICHECK(m_ % meta.atom_m == 0);
747
      PrimExpr atom_idx = FloorDiv(i, meta.atom_m) +
748
                          FloorDiv(j, meta.atom_n) * (m_ / meta.atom_m);
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
      PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i"
      PrimExpr aj = FloorMod(j, meta.atom_n);
      if (meta.atom_m == 128) {
        // Layout D
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d)
        res = Layout(Array{i, j}, {ai, aj + atom_idx * meta.atom_n});
      } else if (meta.atom_m == 64) {
        // Layout E
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e)
        // since .ws variant is used About why we use .ws variant here, please
        // refer to gemm_sm100.h
        res = Layout(Array{i, j}, {FloorDiv(ai, 32) * 32 + FloorMod(ai, 32) +
                                       FloorDiv(aj, meta.atom_n / 2) * 64,
                                   FloorMod(aj, meta.atom_n / 2) +
                                       atom_idx * (meta.atom_n / 2)});
      } else if (meta.atom_m == 32) {
        // Layout G
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-g)
        res = Layout(
            Array{i, j},
            {FloorMod(ai, 32) + FloorDiv(aj, meta.atom_n / 4) * 32,
             FloorMod(aj, meta.atom_n / 4) + atom_idx * (meta.atom_n / 4)});
      } else {
        ICHECK(0);
      }
774
      results.Set(c_, res);
775
    }
776
  } else if (TargetIsCDNA(T.target)) {
777
    ICHECK(c_.scope() == "local.fragment")
778
        << "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
779
        << c_.scope();
780
    if (TargetIsDCU(T.target)) {
guchaoyang's avatar
guchaoyang committed
781
782
      auto fragment = makeGemmFragmentCDCU(m_, n_, m_ / warp_m, n_ / warp_n,
                                           c_->dtype.bits());
guchaoyang's avatar
guchaoyang committed
783
      results.Set(c_, fragment->BindThreadRange(thread_range));
Lukinon's avatar
Lukinon committed
784
    } else {
guchaoyang's avatar
guchaoyang committed
785
786
      auto fragment = makeGemmFragmentCCDNA(m_, n_, m_ / warp_m, n_ / warp_n,
                                            c_->dtype.bits());
guchaoyang's avatar
guchaoyang committed
787
      results.Set(c_, fragment->BindThreadRange(thread_range));
Lukinon's avatar
Lukinon committed
788
    }
789
790
    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
791
      auto shared_layout = makeGemmABLayoutCDNA(
792
793
794
795
796
797
798
799
          *as_const_int(a_->shape[dim_A - 2]),
          *as_const_int(a_->shape[dim_A - 1]), a_->dtype.bits(), kPack_);
      results.Set(a_, shared_layout);
    } else if (a_.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentACDNA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                a_->dtype.bits(), kPack_, transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
800
801
802
    } else {
      ICHECK(0);
    }
803
804
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
805
      auto shared_layout = makeGemmABLayoutCDNA(
806
807
          *as_const_int(b_->shape[dim_B - 2]),
          *as_const_int(b_->shape[dim_B - 1]), b_->dtype.bits(), kPack_);
808

809
810
      results.Set(b_, shared_layout);
    } else if (b_.scope() == "local.fragment") {
811
      auto fragment =
812
813
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
814
815
816
817
818
819
820
821
822
823
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

824
// Register the GEMM tile operator under the name "gemm". Calls to it are
// tagged CallEffectKind::kOpaque, i.e. no assumptions may be made about its
// side effects.
//
// NOTE(review): set_num_inputs(5) looks inconsistent with the Gemm
// constructor in this file, whose documented positional argument list
// (A/B/C pointers, transpose flags, M/N/K, policy, clear_accum, strides,
// offsets, optional kPack / wg_wait) is much longer than 5 entries --
// confirm whether this arity is intentional.
TIR_REGISTER_TL_TILE_OP(Gemm, gemm)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque))
    .set_num_inputs(5);

// Attach a TScriptPrinterName attribute with the value "GemmWarpPolicy" to
// the "tl.GemmWarpPolicy" op. Presumably this is the name the TVMScript
// printer emits for the policy object.
TVM_REGISTER_OP("tl.GemmWarpPolicy")
    .set_attr<TScriptPrinterName>(/*attr_name=*/"TScriptPrinterName",
                                  /*value=*/"GemmWarpPolicy");

TVM_FFI_STATIC_INIT_BLOCK() {
833
834
835
836
837
  GemmNode::RegisterReflection();
  GemmWarpPolicyNode::RegisterReflection();
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
                        [](GemmWarpPolicy policy, int M, int N, int block_size,
838
                           Target target, GemmInst gemm_inst) {
839
                          policy->computeWarpPartition(M, N, block_size, target,
840
                                                       gemm_inst);
841
                        });
842
}
843

844
} // namespace tl
845
} // namespace tvm