/*!
 * \file tl/op/gemm.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm.h"

#include "builtin.h"

#include <fstream>

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Construct a Gemm operator from serialized TL arguments and a buffer
 * map.
 *
 * This constructor deserializes operator parameters from `args` and resolves
 * buffer references via `vmap`, populating an internal GemmNode with:
 * - buffer regions for A, B, C and their corresponding Buffer objects,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait,
 * - an optional mbarrier pointer and the C tile coordinates.
 *
 * The populated GemmNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *     [A, B, C (each a BufferRegion, BufferLoad, tl.region call, or
 *      tvm_access_ptr), trans_A (Bool), trans_B (Bool),
 *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      (optional) kPack (Int), (optional) wg_wait (Int),
 *      mbar pointer (may be a non-call placeholder when unused),
 *      C coordinate (PrimExpr), C coordinate (PrimExpr)]
 * @param vmap Mapping from access pointer vars to Buffer objects used to
 *   resolve the Buffer corresponding to each pointer argument.
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region)
// to BufferRegion
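// Example (illustrative): a BufferLoad `A[i, Ramp(0, 1, 64)]` becomes the
// BufferRegion `A[i:i+1, 0:64]`; a plain scalar index contributes an
// extent-1 range.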
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
                                            const BufferMap &vmap) {
  // Case 1: Already a BufferRegion
  if (arg->IsInstance<BufferRegionNode>()) {
    return Downcast<BufferRegion>(arg);
  }

  // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
  // extent=1)
  if (const auto *load = arg.as<BufferLoadNode>()) {
    Array<Range> ranges;
    for (const PrimExpr &index : load->indices) {
      if (const auto *ramp = index.as<RampNode>()) {
        ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
        ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
            << "Only stride-1 Ramp is supported in GEMM region conversion";
        ICHECK(ramp->lanes.as<IntImmNode>())
            << "Scalable vector lanes not supported in GEMM region conversion";
        ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
      } else {
        ranges.push_back(Range::FromMinExtent(index, 1));
      }
    }
    return BufferRegion(load->buffer, ranges);
  }

  // Case 3: Call nodes
  if (const auto *call = arg.as<CallNode>()) {
    // tl.region(...) — reconstruct via RegionOp
    if (call->op.same_as(RegionOp::Get())) {
      RegionOp region(call->args, vmap);
      return BufferRegion(region->GetBuffer(), region->GetRanges());
    }
    // builtin.tvm_access_ptr(...) — map var to Buffer and take full region
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      Var var = Downcast<Var>(call->args[1]);
      Buffer buf = vmap[var];
      Array<Range> ranges;
      for (PrimExpr extent : buf->shape) {
        ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
      }
      return BufferRegion(buf, ranges);
    }
  }

  LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg;
  throw; // Unreachable, keeps compiler happy
}

// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
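// Example (illustrative): for a buffer of shape [8, 64, 128] and region
// [4:5, 0:64, 0:128], the element offset is 4 * (64 * 128) and the extent is
// 64 * 128 elements.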
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
                                        int rw_mask) {
  Buffer buf = region->buffer;
  int ndim = static_cast<int>(buf->shape.size());
  ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";

  // Compute row-major strides
  std::vector<PrimExpr> strides(ndim);
  PrimExpr one = make_const(buf->shape[0].dtype(), 1);
  PrimExpr cur = one;
  for (int i = ndim - 1; i >= 0; --i) {
    strides[i] = cur;
    cur = cur * buf->shape[i];
  }

  // Offset: sum_{i in [0..ndim-3]} min_i * stride_i
  PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
  for (int i = 0; i < ndim - 2; ++i) {
    offset = offset + region->region[i]->min * strides[i];
  }

  // Extent: last two extents product (elements)
  PrimExpr extent =
      region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;

  // ptype and return handle
  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
                           IntImm(DataType::Int(32), rw_mask)};
  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}

Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();

  node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
  node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
  node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);

  node->a_ = node->aRegion_->buffer;
  node->b_ = node->bRegion_->buffer;
  node->c_ = node->cRegion_->buffer;
  node->transA_ = args[3].as<Bool>().value();
  node->transB_ = args[4].as<Bool>().value();
  node->m_ = args[5].as<IntImm>().value()->value;
  node->n_ = args[6].as<IntImm>().value()->value;
  node->k_ = args[7].as<IntImm>().value()->value;
  node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clearAccum_ = args[9].as<PrimExpr>().value();
  node->strideA_ = args[10].as<IntImm>().value()->value;
  node->strideB_ = args[11].as<IntImm>().value()->value;
  node->offsetA_ = args[12].as<IntImm>().value()->value;
  node->offsetB_ = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack_ = args[14].as<IntImm>().value()->value;
    if (node->kPack_ != 1 && node->kPack_ != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wgWait_ = args[15].as<IntImm>().value()->value;
  }
  node->mbarPtr_ = args[16];
  if (node->mbarPtr_.as<CallNode>()) {
    node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
  } else {
    node->mbar_ = std::nullopt;
  }
  node->cCoords_ = Array<PrimExpr>(
      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
  data_ = std::move(node);
}

/**
 * @brief Create a copy of this GemmNode as a TileOperator.
 *
 * Constructs a new GemmNode by copying the current node state and returns it
 * wrapped in a Gemm TileOperator.
 *
 * @return TileOperator A Gemm operator that owns a copy of this node.
 */
TileOperator GemmNode::Clone() const {
  auto op = tvm::ffi::make_object<GemmNode>(*this);
  return Gemm(op);
}

bool GemmNode::allowTcgen5Mma(Target target) const {
  return TargetIsSm100(target) &&
         ((a_.scope() == "shared.dyn" || a_.scope() == "shared" ||
           a_.scope() == "shared.tmem") &&
          (b_.scope() == "shared.dyn" || b_.scope() == "shared") &&
          c_.scope() == "shared.tmem") &&
         GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first;
}

bool GemmNode::allowWgmma(int block_size, Target target) const {
  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
         TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
         checkWgmma();
}

GemmInst GemmNode::getGemmInst(int block_size, Target target) const {
  if (allowTcgen5Mma(target)) {
    return GemmInst::kTCGEN5MMA;
  } else if (allowWgmma(block_size, target)) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target;
    return GemmInst::kMMA;
  }
}

std::pair<int, int> GemmWarpPolicyNode::computeWarpPartition(
    int M, int N, int block_size, Target target, GemmInst gemm_inst) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning
  }
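  // Illustrative example (hypothetical sizes): with block_size = 256 on Hopper
  // (8 warps, WGMMA) and M = N = 128, the FullRow policy below picks
  // m_warp = 8, n_warp = 1 since 128 % (8 * 16) == 0; the Square policy would
  // instead search multiples of 4 for the most balanced split.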

  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  int kNPerWarp = 8;            // Columns processed by a single warp
  if (TargetIsVolta(target)) {
    kNPerWarp = 16;
  }
  ICHECK(M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << M;
  ICHECK(N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << N;

  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0)
        << "Warp-Group MMA requires the thread count to be a multiple of 128.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->isFullRow()) {
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->isFullCol()) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                 // Initially assume all on N
      if (N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is a multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->isSquare()) {
      // Exhaustive search, but m must be a multiple of 4
      int max_m = M / kMPerWarp;
      int max_n = N / kNPerWarp;

      float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
        << ", n_warp: " << n_warp << ", num_warps: " << num_warps;

    // Store the computed values in the object's member variables
    this->m_warp = m_warp;
    this->n_warp = n_warp;

    return {m_warp, n_warp};
  }

  if (this->isFullRow()) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->isFullCol()) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->isSquare()) {
    // First calculate the maximum possible warps for each dimension
    int max_m_warps =
        M / kMPerWarp; // Each warp needs at least 16 elements in M

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (N > 0) {
      ideal_ratio = static_cast<float>(M) / N;
    }

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();
    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
      // m_per_warp and n_per_warp must be at least 1
      if (m_per_warp < 1 || n_per_warp < 1)
        continue;
      // m * n must equal num_warps
      if (m * n != num_warps)
        continue;

      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  ICHECK(m_warp * n_warp == num_warps)
      << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
      << ", n_warp: " << n_warp << ", num_warps: " << num_warps;

  // Store the computed values in the object's member variables
  this->m_warp = m_warp;
  this->n_warp = n_warp;

  return {m_warp, n_warp};
}

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and
 *     K % 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 *     and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
bool GemmNode::checkWgmma() const {
  if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
    return false;
  }
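  // Example (illustrative): fp16 A/B accumulating into fp32 needs K % 16 == 0;
  // fp8 (e4m3/e5m2) mixes additionally require A non-transposed, B transposed,
  // and K % 32 == 0.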

  if (c_->dtype == DataType::Float(16)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else
      return false;
  } else if (c_->dtype == DataType::Float(32)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::BFloat(16) &&
             b_->dtype == DataType::BFloat(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::Float(32) &&
             b_->dtype == DataType::Float(32))
      return (!transA_) && transB_ && k_ % 8 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else
      return false;
  } else if (c_->dtype == DataType::Int(32)) {
    if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}

/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it matches the pattern
 * "sm_<num>", returns <num> as an int. If the attribute is present but does not
 * match that pattern, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked via
 * ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<tvm::ffi::String>("arch");
  ICHECK(s.has_value());
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
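    // e.g. "sm_80" -> 80; std::stoi stops at the first non-digit, so "sm_90a"
    // also parses as 90.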
  } else {
    arch_int = 0;
  }
  return arch_int;
}

/**
 * @brief Lower the GEMM operator to a TL TIR call expression.
 *
 * Constructs a tl::gemm call string parameterized by M, N, K, warp partition,
 * transpose flags, accumulation clearing, target-specific stride/offset/kPack
 * and an optional warp-group wait (wg_wait) value, then returns an
 * Evaluate(call) node
 * invoking tl::tl_gemm with the composed string and the A/B/C buffer handles.
 *
 * @param T Contains lowering context including thread bounds and target.
 * @param analyzer Optional arithmetic analyzer used by lowering (may be
 * nullptr).
 * @return Stmt A TIR statement representing the evaluated TL GEMM call.
 */
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = getGemmInst(block_size, T.target);
  auto [warp_m, warp_n] =
      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);

  // Build access pointers from regions locally
  PrimExpr Aptr = MakeAccessPtrFromRegion(aRegion_, /*r*/ 1);
  PrimExpr Bptr = MakeAccessPtrFromRegion(bRegion_, /*r*/ 1);
  PrimExpr Cptr = MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3);

  std::stringstream ss;
  std::string op_name;

  if (gemm_inst == GemmInst::kTCGEN5MMA) {
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
    ICHECK(can_use_tcgen5mma);
    ICHECK(b_.scope() == "shared.dyn" || b_.scope() == "shared");
    ICHECK(c_.scope() == "shared.tmem");
    ICHECK(mbar_.has_value()) << "mbar must be provided for TCGEN5MMA";
    if (a_.scope() == "shared.tmem") {
      op_name = "tl::tcgen5mma_gemm_ts";
    } else if (a_.scope() == "shared.dyn" || a_.scope() == "shared") {
      op_name = "tl::tcgen5mma_gemm_ss";
    } else {
      ICHECK(0)
          << "Unsupported A scope for TCGEN5MMA: "
          << a_.scope(); // If this is triggered, it means Tilelang has bugs.
    }
    ICHECK(wgWait_ == -1)
        << "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
           "use wg_wait = -1 and manually synchronize with mbarrier.";

    std::string accum_dtype = "";
    if (c_->dtype.is_float()) {
      if (c_->dtype.bits() == 32) {
        accum_dtype = "float";
      }
    }
    ICHECK(!accum_dtype.empty())
        << "Unsupported C dtype for TCGEN5MMA: " << c_->dtype;
    ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
    ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", ";
    ss << transA_ << ", " << transB_ << ", ";
    ss << accum_dtype;
    ss << ">";

    auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
    Array<PrimExpr> new_args;
    new_args.push_back(StringImm(ss.str()));
    new_args.push_back(Aptr);
    new_args.push_back(Bptr);
    new_args.push_back(BufferLoad(C_buffer, cCoords_));
    new_args.push_back(mbarPtr_);
    new_args.push_back(clearAccum_);
    auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);

    // Since TCGEN5MMA atoms provided by CUTLASS always have an internal
    // `elect_one_sync()`, we check if we are calling it using full warps
    constexpr int warp_size = 32;
    ICHECK(
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, warp_size), 0) &&
        analyzer->CanProveEqual(FloorMod(T.thread_bounds->extent, warp_size),
                                0))
        << "TCGEN5MMA requires thread bounds to be multiples of warp size (32) "
           "and aligned to warps.";
    if (analyzer->CanProveEqual(T.thread_bounds->extent, warp_size)) {
      // If the thread bounds is exactly one warp, we can use the original call
      return Evaluate(new_call);
    } else {
      // Add an if-else clause
      auto tcgen5mma_call =
          IfThenElse(EQ(FloorDiv(T.thread_var, warp_size),
                        FloorDiv(T.thread_bounds->min, warp_size)),
                     Evaluate(new_call));
      return tcgen5mma_call;
    }
  }

  if (a_.scope() == "local.fragment") {
    ICHECK(b_.scope() != "local.fragment");
    ICHECK(!transA_)
        << "gemm_rs requires the A operand to be in non-transposed layout.";
    op_name = "tl::gemm_rs";
  } else if (b_.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  } else {
    op_name = "tl::gemm_ss";
  }
  ICHECK(c_.scope() == "local.fragment");
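  // Illustrative example (hypothetical values): an Ampere fp16 GEMM with
  // M = N = 128, K = 32, a 2x2 warp partition, A/B in shared memory, B
  // transposed and no accumulator clearing would compose a string such as
  // "tl::gemm_ss<128, 128, 32, 2, 2, 0, 1, 0, 32, 32, 0, 0>" (the trailing
  // four values being stride_A, stride_B, offset_A, offset_B).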

  ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << transA_ << ", " << transB_;
  auto clear_accum_bool = clearAccum_.as<Bool>();
  ICHECK(clear_accum_bool.has_value())
      << "clear_accum must be a constant Bool type, got " << clearAccum_;
  ss << ", " << bool(clear_accum_bool.value());
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << strideA_ << ", " << strideB_;
    ss << ", " << offsetA_ << ", " << offsetB_;
  }
  if (TargetIsCDNA(T.target)) {
    // for CDNA gemm, we need to specify kPack
    ss << ", " << kPack_;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }

  // Emit wg_wait if necessary
  if (TargetIsHopper(T.target)) {
    if (wgWait_ != 0) {
      ss << ", " << wgWait_;
    }
  } else if (TargetIsSm100(T.target)) {
    // NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
    // but all threads need to wait, so we emit another statement for cases
    // where wg_wait == 0.
    ICHECK(wgWait_ == 0 || wgWait_ == -1)
        << "wg_wait must be 0 or -1 for Sm100";
  } else {
    ICHECK(wgWait_ == 0)
        << "wg_wait must be 0 for non-Hopper and non-Sm100 targets";
  }
  ss << ">";

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}

/**
 * @brief Infer and bind target-specific memory/layout mappings for A, B, and C.
 *
 * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
 * operator according to the target architecture, thread bounds, warp
 * partitioning, data types, and transpose flags, then binds fragment layouts
 * to the thread range when required.
 *
 * Preconditions:
 * - C.scope() == "local.fragment" (or "shared.tmem" for the TCGEN5MMA path)
 *
 * Side effects:
 * - Marks layout inference as completed (sets completed_ = true).
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
 * @param T Input layout-inference context (provides thread bounds and target).
 * @return LayoutMap mapping A, B, and C to their inferred layouts.
 */
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = getGemmInst(block_size, T.target);
  auto [warp_m, warp_n] =
      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
  if (TargetIsVolta(T.target)) {
    ICHECK(c_.scope() == "local.fragment")
        << "Volta gemm only supports C in local.fragment scope, got "
        << c_.scope();
    auto fragment = makeGemmVoltaFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
                                           c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));
    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      results.Set(a_, makeGemmVoltaABLayout(*as_const_int(a_->shape[dim_A - 2]),
                                            *as_const_int(a_->shape[dim_A - 1]),
                                            true, !transA_));
    } else if (a_.scope() == "local.fragment") {
      ICHECK(transA_ == false);
      auto fragment =
          makeGemmVoltaFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(b_.scope() == "shared" || b_.scope() == "shared.dyn");
    int dim_B = b_->shape.size();
    results.Set(b_, makeGemmVoltaABLayout(*as_const_int(b_->shape[dim_B - 2]),
                                          *as_const_int(b_->shape[dim_B - 1]),
                                          false, transB_));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target) ||
             (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
    ICHECK(c_.scope() == "local.fragment")
        << "MMA only supports C in local.fragment scope, got " << c_.scope();

    auto fragment =
        makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));

    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
      results.Set(a_,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   a_->dtype.bits(), !transA_));
    } else if (a_.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                        a_->dtype.bits(), transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
      results.Set(b_,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   b_->dtype.bits(), transB_));
    } else if (b_.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    ICHECK(c_.scope() == "local.fragment")
        << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ")
        << "only supports C in local.fragment scope, got " << c_.scope();
    auto fragment = gemm_inst == GemmInst::kWGMMA
                        ? makeGemmFragmentCHopper(m_, n_, m_ / warp_m,
                                                  n_ / warp_n, c_->dtype.bits())
                        : makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
                                            c_->dtype.bits());
    results.Set(c_, fragment->BindThreadRange(thread_range));
    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
      const int64_t continuity =
          transA_ ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       a_->dtype.bits(), !transA_)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 a_->dtype.bits(), !transA_);
      results.Set(a_, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                        a_->dtype.bits(), transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    }
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
      const int64_t continuity =
          transB_ ? mat_continuous : mat_continuous / warp_n;

      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       b_->dtype.bits(), transB_)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 b_->dtype.bits(), transB_);
      results.Set(b_, ABLayout);
    } else {
      auto fragment =
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
    }
  } else if (gemm_inst == GemmInst::kTCGEN5MMA) {
    ICHECK(c_.scope() == "shared.tmem")
        << "TCGEN5MMA only supports C in shared.tmem scope, got " << c_.scope();
    ICHECK(a_.scope() == "shared.dyn" || a_.scope() == "shared")
        << "Current TCGEN5MMA only supports A in shared.dyn or shared scope";
    auto [can_use_tcgen5mma, meta] =
        GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
    ICHECK(can_use_tcgen5mma);
    {
      int dim_A = a_->shape.size();
      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
      results.Set(a_, makeGemmABLayoutSm100(mat_stride, mat_continuous,
                                            mat_continuous, a_->dtype.bits(),
                                            transA_ ? 1 : 2));
    }
    {
      int dim_B = b_->shape.size();
      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
      const int64_t continuity = mat_continuous;
      results.Set(b_,
                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
                                        b_->dtype.bits(), transB_ ? 2 : 1));
    }
    {
      Layout res;
      IterVar i = make_itervar("i", m_);
      IterVar j = make_itervar("j", n_);
      ICHECK(m_ % meta.atom_m == 0);
      PrimExpr atom_idx = FloorDiv(i, meta.atom_m) +
                          FloorDiv(j, meta.atom_n) * (m_ / meta.atom_m);
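      // atom_idx enumerates the (m_ / atom_m) x (n_ / atom_n) grid of atoms in
      // column-major order (the row index varies fastest).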
      PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i"
      PrimExpr aj = FloorMod(j, meta.atom_n);
      if (meta.atom_m == 128) {
        // Layout D
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d)
        res = Layout(Array{i, j}, {ai, aj + atom_idx * meta.atom_n});
      } else if (meta.atom_m == 64) {
        // Layout E
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e)
        // since the .ws variant is used. For why we use the .ws variant here,
        // please refer to gemm_sm100.h
        res = Layout(Array{i, j}, {FloorDiv(ai, 32) * 32 + FloorMod(ai, 32) +
                                       FloorDiv(aj, meta.atom_n / 2) * 64,
                                   FloorMod(aj, meta.atom_n / 2) +
                                       atom_idx * (meta.atom_n / 2)});
      } else if (meta.atom_m == 32) {
        // Layout G
        // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-g)
        res = Layout(
            Array{i, j},
            {FloorMod(ai, 32) + FloorDiv(aj, meta.atom_n / 4) * 32,
             FloorMod(aj, meta.atom_n / 4) + atom_idx * (meta.atom_n / 4)});
      } else {
        ICHECK(0);
      }
      results.Set(c_, res);
    }
  } else if (TargetIsCDNA(T.target)) {
    ICHECK(c_.scope() == "local.fragment")
        << "CDNA gemm (MFMA) only supports C in local.fragment scope, got "
        << c_.scope();
    if (TargetIsDCU(T.target)) {
      auto fragment = makeGemmFragmentCDCU(m_, n_, m_ / warp_m, n_ / warp_n,
                                           c_->dtype.bits());
      results.Set(c_, fragment->BindThreadRange(thread_range));
    } else {
      auto fragment = makeGemmFragmentCCDNA(m_, n_, m_ / warp_m, n_ / warp_n,
                                            c_->dtype.bits());
      results.Set(c_, fragment->BindThreadRange(thread_range));
    }
    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
      int dim_A = a_->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(a_->shape[dim_A - 2]),
          *as_const_int(a_->shape[dim_A - 1]), a_->dtype.bits(), kPack_);
      results.Set(a_, shared_layout);
    } else if (a_.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentACDNA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                a_->dtype.bits(), kPack_, transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
      int dim_B = b_->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(b_->shape[dim_B - 2]),
          *as_const_int(b_->shape[dim_B - 1]), b_->dtype.bits(), kPack_);
      results.Set(b_, shared_layout);
    } else if (b_.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
      results.Set(b_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tl.GemmWarpPolicy")
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");

TVM_FFI_STATIC_INIT_BLOCK() {
  GemmNode::RegisterReflection();
  GemmWarpPolicyNode::RegisterReflection();
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
                        [](GemmWarpPolicy policy, int M, int N, int block_size,
                           Target target, GemmInst gemm_inst) {
                          policy->computeWarpPartition(M, N, block_size, target,
                                                       gemm_inst);
                        });
}

} // namespace tl
} // namespace tvm