"python/sglang/srt/models/internlm2.py" did not exist on "0992d85f92688035cd669d12735518faba93b545"
/*!
 * \file tl/op/gemm.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm.h"

#include "builtin.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "tcgen5_meta.h"
#include "utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Construct a Gemm operator from serialized TL arguments.
 *
 * This constructor deserializes operator parameters from `args`, normalizing
 * the A/B/C arguments into buffer regions and populating an internal GemmNode
 * with:
 * - buffer regions and Buffer objects for A, B, and C,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
 *      M (Int), N (Int), K (Int), policy (Int), clear_accum (PrimExpr),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      (optional) kPack (Int), (optional) wg_wait (Int),
 *      (optional) mbar (BufferLoad), C coordinates (two PrimExprs)]
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
// NormalizeToBufferRegion moved to src/op/utils.{h,cc}

// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}

Gemm::Gemm(Array<PrimExpr> args) {
  ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();

  node->aRegion_ = NormalizeToBufferRegion(args[0]);
  node->bRegion_ = NormalizeToBufferRegion(args[1]);
  node->cRegion_ = NormalizeToBufferRegion(args[2]);

  node->a_ = node->aRegion_->buffer;
  node->b_ = node->bRegion_->buffer;
  node->c_ = node->cRegion_->buffer;
  node->transA_ = args[3].as<Bool>().value();
  node->transB_ = args[4].as<Bool>().value();
  node->m_ = args[5].as<IntImm>().value()->value;
  node->n_ = args[6].as<IntImm>().value()->value;
  node->k_ = args[7].as<IntImm>().value()->value;
  node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clearAccum_ = args[9].as<PrimExpr>().value();
  node->strideA_ = args[10].as<IntImm>().value()->value;
  node->strideB_ = args[11].as<IntImm>().value()->value;
  node->offsetA_ = args[12].as<IntImm>().value()->value;
  node->offsetB_ = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack_ = args[14].as<IntImm>().value()->value;
    if (node->kPack_ != 1 && node->kPack_ != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wgWait_ = args[15].as<IntImm>().value()->value;
  }
  if (args.size() > 16) {
    if (const auto *load = args[16].as<BufferLoadNode>()) {
      node->mbarRegion_ =
          NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
      node->mbar_ = node->mbarRegion_->buffer;
    } else {
      node->mbar_ = std::nullopt;
    }
  }
  node->cCoords_ = Array<PrimExpr>(
      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
  data_ = std::move(node);
}

/**
 * @brief Create a copy of this GemmNode as a TileOperator.
 *
 * Constructs a new GemmNode by copying the current node state and returns it
 * wrapped in a Gemm TileOperator.
 *
 * @return TileOperator A Gemm operator that owns a copy of this node.
 */
TileOperator GemmNode::Clone() const {
  auto op = tvm::ffi::make_object<GemmNode>(*this);
  return Gemm(op);
}

bool GemmNode::allowTcgen5Mma(Target target) const {
  return TargetIsSm100(target) &&
         ((a_.scope() == "shared.dyn" || a_.scope() == "shared" ||
           a_.scope() == "shared.tmem") &&
          (b_.scope() == "shared.dyn" || b_.scope() == "shared") &&
          c_.scope() == "shared.tmem") &&
         GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first;
}

bool GemmNode::allowWgmma(int block_size, Target target) const {
  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
         TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
         checkWgmma();
}

GemmInst GemmNode::getGemmInst(int block_size, Target target) const {
  if (allowTcgen5Mma(target)) {
    return GemmInst::kTCGEN5MMA;
  } else if (allowWgmma(block_size, target)) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target;
    return GemmInst::kMMA;
  }
}
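
// Illustrative dispatch (values hypothetical): on a Hopper (sm_90) target with
// 256 threads (8 warps), WGMMA not disabled, M >= 64 and operands/dtypes
// accepted by checkWgmma(), getGemmInst() selects kWGMMA; the same GEMM
// compiled for sm_80 falls back to kMMA, and a CDNA target selects kMFMA.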

std::pair<int, int> GemmWarpPolicyNode::computeWarpPartition(
    int M, int N, int block_size, Target target, GemmInst gemm_inst) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning
  }

  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  int kNPerWarp = 8;            // Columns processed by a single warp
  if (TargetIsVolta(target)) {
    kNPerWarp = 16;
  }
  ICHECK(M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << M;
  ICHECK(N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << N;

  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->isFullRow()) {
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->isFullCol()) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                 // Initially assume all on N
      if (N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->isSquare()) {
      // Exhaustive search, but m must be a multiple of 4
      int max_m = M / kMPerWarp;
      int max_n = N / kNPerWarp;

      float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
        << ", n_warp: " << n_warp << ", num_warps: " << num_warps;

    // Store the computed values in the object's member variables
    this->m_warp = m_warp;
    this->n_warp = n_warp;

    return {m_warp, n_warp};
  }

  if (this->isFullRow()) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->isFullCol()) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->isSquare()) {
    // First calculate the maximum possible warps for each dimension
    int max_m_warps =
        M / kMPerWarp; // Each warp needs at least 16 elements in M

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (N > 0) {
      ideal_ratio = static_cast<float>(M) / N;
    }

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();
    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
      // m_per_warp and n_per_warp must be at least 1
      if (m_per_warp < 1 || n_per_warp < 1)
        continue;
      // m * n must equal num_warps
      if (m * n != num_warps)
        continue;

      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  ICHECK(m_warp * n_warp == num_warps)
      << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
      << ", n_warp: " << n_warp << ", num_warps: " << num_warps;

  // Store the computed values in the object's member variables
  this->m_warp = m_warp;
  this->n_warp = n_warp;

  return {m_warp, n_warp};
}
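
// Worked example (illustrative, kSquare policy on a non-WGMMA target): with
// M = N = 128 and 8 warps (256 threads), kMPerWarp = 16 and kNPerWarp = 8,
// the balance search above picks m_warp = 2, n_warp = 4, i.e. each warp
// covers a 64x32 tile of the accumulator.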

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and
 *     K % 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 *     and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
bool GemmNode::checkWgmma() const {
  if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
    return false;
  }

  if (c_->dtype == DataType::Float(16)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else
      return false;
  } else if (c_->dtype == DataType::Float(32)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::BFloat(16) &&
             b_->dtype == DataType::BFloat(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::Float(32) &&
             b_->dtype == DataType::Float(32))
      return (!transA_) && transB_ && k_ % 8 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else
      return false;
  } else if (c_->dtype == DataType::Int(32)) {
    if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}
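
// Illustrative checks (derived from the table above): fp16 x fp16 -> fp32
// with K = 64 passes (64 % 16 == 0) while K = 24 does not; int8 x int8 ->
// int32 additionally requires A non-transposed, B transposed, and K % 32 == 0.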

/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it matches the pattern
 * "sm_<num>", returns <num> as an int. If the attribute is present but does not
 * match that pattern, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked via
 * ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<tvm::ffi::String>("arch");
  ICHECK(s.has_value());
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
  } else {
    arch_int = 0;
  }
  return arch_int;
}
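
// Example (illustrative): "sm_90a" parses to 90 (std::stoi stops at the first
// non-digit); a non-"sm_" arch string such as "gfx942" yields 0.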

/**
 * @brief Lower the GEMM operator to a TL TIR call expression.
 *
 * Constructs a tl::gemm call string parameterized by M, N, K, warp partition,
 * transpose flags, accumulation clearing, target-specific stride/offset/kPack
 * and optional workgroup wait value, then returns an Evaluate(call) node
 * invoking tl::tl_gemm with the composed string and the A/B/C buffer handles.
 *
 * @param T Contains lowering context including thread bounds and target.
 * @param analyzer Optional arithmetic analyzer used by lowering (may be
 * nullptr).
 * @return Stmt A TIR statement representing the evaluated TL GEMM call.
 */
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = getGemmInst(block_size, T.target);
  auto [warp_m, warp_n] =
      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);

  // Build access pointers from regions locally
  PrimExpr Aptr =
      MakeAccessPtrFromRegion(aRegion_, /*r*/ 1, /*require_2d*/ true);
  PrimExpr Bptr =
      MakeAccessPtrFromRegion(bRegion_, /*r*/ 1, /*require_2d*/ true);
  PrimExpr Cptr =
      MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3, /*require_2d*/ true);

  std::stringstream ss;
  std::string op_name;

  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
    ICHECK(can_use_tcgen5mma);
    ICHECK(b_.scope() == "shared.dyn" || b_.scope() == "shared");
    ICHECK(c_.scope() == "shared.tmem");
    ICHECK(mbar_.has_value()) << "mbar must be provided for TCGEN5MMA";
    if (a_.scope() == "shared.tmem") {
      op_name = "tl::tcgen5mma_gemm_ts";
    } else if (a_.scope() == "shared.dyn" || a_.scope() == "shared") {
      op_name = "tl::tcgen5mma_gemm_ss";
    } else {
      ICHECK(0)
          << "Unsupported A scope for TCGEN5MMA: "
          << a_.scope(); // If this is triggered, it means TileLang has a bug.
    }
    ICHECK(wgWait_ == -1)
        << "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
           "use "
           "wg_wait = -1 and manually synchronize with mbarrier.";

    std::string accum_dtype = "";
    if (c_->dtype.is_float()) {
      if (c_->dtype.bits() == 32) {
        accum_dtype = "float";
      }
    }
    ICHECK(!accum_dtype.empty())
        << "Unsupported C dtype for TCGEN5MMA: " << c_->dtype;
    ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
    ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", ";
    ss << transA_ << ", " << transB_ << ", ";
    ss << accum_dtype;
    ss << ">";

    auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
    Array<PrimExpr> new_args;
    auto mbarPtr =
        MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true);
    new_args.push_back(StringImm(ss.str()));
    new_args.push_back(Aptr);
    new_args.push_back(Bptr);
    new_args.push_back(BufferLoad(C_buffer, cCoords_));
    new_args.push_back(mbarPtr);
    new_args.push_back(clearAccum_);
    auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);

    // Since TCGEN5MMA atoms provided by CUTLASS always have an internal
    // `elect_one_sync()`, we check if we are calling it using full warps
    constexpr int warp_size = 32;
    ICHECK(
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, warp_size), 0) &&
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->extent, warp_size),
                                0))
        << "TCGEN5MMA requires thread bounds to be multiples of warp size (32) "
           "and aligned to warps.";
    if (analyzer->CanProveEqual(T.thread_bounds->extent, warp_size)) {
      // If the thread bounds is exactly one warp, we can use the original call
      return Evaluate(new_call);
    } else {
      // Add an if-else clause
      auto tcgen5mma_call =
          IfThenElse(EQ(FloorDiv(T.thread_var, warp_size),
                        FloorDiv(T.thread_bounds->min, warp_size)),
                     Evaluate(new_call));
      return tcgen5mma_call;
    }
  }

  if (a_.scope() == "local.fragment") {
    ICHECK(b_.scope() != "local.fragment");
    ICHECK(!transA_)
        << "gemm_rs requires the A operand to be in non-transposed layout.";
    op_name = "tl::gemm_rs";
  } else if (b_.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  } else {
    op_name = "tl::gemm_ss";
  }
  ICHECK(c_.scope() == "local.fragment");

  ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << transA_ << ", " << transB_;
  auto clear_accum_bool = clearAccum_.as<Bool>();
  ICHECK(clear_accum_bool.has_value())
      << "clear_accum must be a constant Bool type, got " << clearAccum_;
  ss << ", " << bool(clear_accum_bool.value());
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << strideA_ << ", " << strideB_;
    ss << ", " << offsetA_ << ", " << offsetB_;
  }
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack_;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }

  // Emit wg_wait if necessary
  if (TargetIsHopper(T.target)) {
    if (wgWait_ != 0) {
      ss << ", " << wgWait_;
    }
  } else if (TargetIsSm100(T.target)) {
    // NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
    // but all threads need to wait, so we emit another statement for cases
    // where wg_wait == 0.
    ICHECK(wgWait_ == 0 || wgWait_ == -1)
        << "wg_wait must be 0 or -1 for Sm100";
  } else {
    ICHECK(wgWait_ == 0)
        << "wg_wait must be 0 for non-Hopper and non-Sm100 targets";
  }
  ss << ">";

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}
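
// Illustrative emitted intrinsic (values hypothetical): for an Ampere target
// with M = N = 128, K = 32, a 2x4 warp partition, A non-transposed, B
// transposed and no accumulator clearing, the composed call looks roughly like
//   tl::gemm_ss<128, 128, 32, 2, 4, 0, 1, 0, /*strides*/ ..., /*offsets*/ ...>
// passed to tl::tl_gemm together with the A/B/C access pointers.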

/**
 * @brief Infer and bind target-specific memory/layout mappings for A, B, and C.
 *
 * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
 * operator according to the target architecture, thread bounds, warp
 * partitioning, data types, and transpose flags, then binds fragment layouts
 * to the thread range when required.
 *
 * Preconditions:
 * - C.scope() == "local.fragment" (or "shared.tmem" for the TCGEN5MMA path)
 *
 * Side effects:
 * - Marks layout inference as completed (sets completed_ = true).
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
 * @param T Input layout-inference context (provides thread bounds and target).
 * @return LayoutMap mapping A, B, and C to their inferred layouts.
 */
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = getGemmInst(block_size, T.target);
  auto [warp_m, warp_n] =
      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
  if (TargetIsVolta(T.target)) {
    ICHECK(c_.scope() == "local.fragment")
        << "Volta gemm only supports C in local.fragment scope, got "
        << c_.scope();
    auto fragment = makeGemmVoltaFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
                                           c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));
    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      results.Set(a_, makeGemmVoltaABLayout(*as_const_int(a_->shape[dim_A - 2]),
                                            *as_const_int(a_->shape[dim_A - 1]),
                                            true, !transA_));
    } else if (a_.scope() == "local.fragment") {
      ICHECK(transA_ == false);
      auto fragment =
          makeGemmVoltaFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(b_.scope() == "shared" || b_.scope() == "shared.dyn");
    int dim_B = b_->shape.size();
    results.Set(b_, makeGemmVoltaABLayout(*as_const_int(b_->shape[dim_B - 2]),
                                          *as_const_int(b_->shape[dim_B - 1]),
                                          false, transB_));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target) ||
             (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
    ICHECK(c_.scope() == "local.fragment")
        << "MMA only supports C in local.fragment scope, got " << c_.scope();

    auto fragment =
        makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));

    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
      results.Set(a_,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   a_->dtype.bits(), !transA_));
    } else if (a_.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                        a_->dtype.bits(), transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
      results.Set(b_,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   b_->dtype.bits(), transB_));
    } else if (b_.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    ICHECK(c_.scope() == "local.fragment")
        << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ")
        << "only supports C in local.fragment scope, got " << c_.scope();
    auto fragment = gemm_inst == GemmInst::kWGMMA
                        ? makeGemmFragmentCHopper(m_, n_, m_ / warp_m,
                                                  n_ / warp_n, c_->dtype.bits())
                        : makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
                                            c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));
    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
      const int64_t continuity =
          transA_ ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       a_->dtype.bits(), !transA_)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 a_->dtype.bits(), !transA_);
      results.Set(a_, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                        a_->dtype.bits(), transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    }
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
      const int64_t continuity =
          transB_ ? mat_continuous : mat_continuous / warp_n;

      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       b_->dtype.bits(), transB_)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 b_->dtype.bits(), transB_);
      results.Set(b_, ABLayout);
    } else {
      auto fragment =
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
    }
  } else if (gemm_inst == GemmInst::kTCGEN5MMA) {
    ICHECK(c_.scope() == "shared.tmem")
        << "TCGEN5MMA only supports C in shared.tmem scope, got " << c_.scope();
    ICHECK(a_.scope() == "shared.dyn" || a_.scope() == "shared")
        << "TCGEN5MMA currently only supports A in shared or shared.dyn scope";
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
    ICHECK(can_use_tcgen5mma);
    {
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
      results.Set(a_, makeGemmABLayoutSm100(mat_stride, mat_continuous,
                                            mat_continuous, a_->dtype.bits(),
                                            transA_ ? 1 : 2));
    }
    {
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
      const int64_t continuity = mat_continuous;
      results.Set(b_,
                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
                                        b_->dtype.bits(), transB_ ? 2 : 1));
    }
    {
      Layout res;
      IterVar i = make_itervar("i", m_);
      IterVar j = make_itervar("j", n_);
      ICHECK(m_ % meta.atom_m == 0);
      PrimExpr atom_idx = FloorDiv(i, meta.atom_m) +
                          FloorDiv(j, meta.atom_n) * (m_ / meta.atom_m);
      PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i"
      PrimExpr aj = FloorMod(j, meta.atom_n);
      if (meta.atom_m == 128) {
        // Layout D
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d)
        res = Layout(Array{i, j}, {ai, aj + atom_idx * meta.atom_n});
      } else if (meta.atom_m == 64) {
        // Layout E
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e)
        // since the .ws variant is used. For why we use the .ws variant here,
        // please refer to gemm_sm100.h
        res = Layout(Array{i, j}, {FloorDiv(ai, 32) * 32 + FloorMod(ai, 32) +
                                       FloorDiv(aj, meta.atom_n / 2) * 64,
                                   FloorMod(aj, meta.atom_n / 2) +
                                       atom_idx * (meta.atom_n / 2)});
      } else if (meta.atom_m == 32) {
        // Layout G
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-g)
        res = Layout(
            Array{i, j},
            {FloorMod(ai, 32) + FloorDiv(aj, meta.atom_n / 4) * 32,
             FloorMod(aj, meta.atom_n / 4) + atom_idx * (meta.atom_n / 4)});
      } else {
        ICHECK(0);
      }
      results.Set(c_, res);
    }
  } else if (TargetIsCDNA(T.target)) {
    ICHECK(c_.scope() == "local.fragment")
        << "CDNA gemm (MFMA) only supports C in local.fragment scope, got "
        << c_.scope();
    auto fragment = makeGemmFragmentCCDNA(m_, n_, m_ / warp_m, n_ / warp_n,
                                          c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));

    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(a_->shape[dim_A - 2]),
          *as_const_int(a_->shape[dim_A - 1]), a_->dtype.bits(), kPack_);
      results.Set(a_, shared_layout);
    } else if (a_.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentACDNA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                a_->dtype.bits(), kPack_, transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(b_->shape[dim_B - 2]),
          *as_const_int(b_->shape[dim_B - 1]), b_->dtype.bits(), kPack_);

      results.Set(b_, shared_layout);
    } else if (b_.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}
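
// Example (illustrative): for a Hopper WGMMA GEMM with A and B in shared
// memory, both operands receive makeGemmABLayoutHopper swizzled shared-memory
// layouts and C receives a warp-group fragment bound to the block's thread
// range; on the TCGEN5MMA path, C instead gets the tcgen05 tensor-memory
// layout constructed above.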

TIR_REGISTER_TL_TILE_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tl.GemmWarpPolicy")
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");

TVM_FFI_STATIC_INIT_BLOCK() {
  GemmNode::RegisterReflection();
  GemmWarpPolicyNode::RegisterReflection();
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
                        [](GemmWarpPolicy policy, int M, int N, int block_size,
                           Target target, GemmInst gemm_inst) {
                          policy->computeWarpPartition(M, N, block_size, target,
                                                       gemm_inst);
                        });
}

} // namespace tl
} // namespace tvm