/*!
 * \file tl/op/gemm.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm.h"

#include "builtin.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Construct a Gemm operator from serialized TL arguments and a buffer
 * map.
 *
 * This constructor deserializes operator parameters from `args` and resolves
 * buffer references via `vmap`, populating an internal GemmNode with:
 * - device pointers for A, B, C and their corresponding Buffer objects,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
 *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      (optional) kPack (Int), (optional) wg_wait (Int)]
 * @param vmap Mapping from access pointer vars to Buffer objects used to
 *   resolve the Buffer corresponding to each pointer argument.
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<GemmNode> node = make_object<GemmNode>();

  node->Aptr = args[0];
  node->Bptr = args[1];
  node->Cptr = args[2];
  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
  node->trans_A = args[3].as<Bool>().value();
  node->trans_B = args[4].as<Bool>().value();
  node->M = args[5].as<IntImm>().value()->value;
  node->N = args[6].as<IntImm>().value()->value;
  node->K = args[7].as<IntImm>().value()->value;
  node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clear_accum = args[9].as<Bool>().value();
  node->stride_A = args[10].as<IntImm>().value()->value;
  node->stride_B = args[11].as<IntImm>().value()->value;
  node->offset_A = args[12].as<IntImm>().value()->value;
  node->offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack = args[14].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wg_wait = args[15].as<IntImm>().value()->value;
  }
  data_ = std::move(node);
}
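
// A minimal sketch of the argument layout this constructor expects, assuming
// hypothetical access-pointer expressions a_ptr/b_ptr/c_ptr whose buffer vars
// are registered in `vmap` (all numeric values illustrative only):
//
//   Array<PrimExpr> args = {a_ptr, b_ptr, c_ptr,
//                           Bool(false), Bool(true), // trans_A, trans_B
//                           128, 128, 32,            // M, N, K
//                           0,                       // warp policy enum value
//                           Bool(false),             // clear_accum
//                           32, 32, 0, 0, // stride_A, stride_B, offset_A, offset_B
//                           1,            // kPack (optional, must be 1 or 2)
//                           0};           // wg_wait (optional)
//   Gemm gemm(args, vmap);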

/**
 * @brief Create a copy of this GemmNode as a TileOperator.
 *
 * Constructs a new GemmNode by copying the current node state and returns it
 * wrapped in a Gemm TileOperator.
 *
 * @return TileOperator A Gemm operator that owns a copy of this node.
 */
TileOperator GemmNode::Clone() const {
  auto op = make_object<GemmNode>(*this);
  return Gemm(op);
}

/**
 * @brief Select the GEMM instruction family (WGMMA, MFMA, or MMA) for the
 * given block size and target.
 *
 * WGMMA is chosen on Hopper only when M >= 64, the block supplies whole
 * warp-groups (a multiple of 4 warps), and CheckWGMMA() accepts the current
 * dtype/transpose/K configuration; otherwise CDNA targets use MFMA and CUDA
 * targets use MMA.
 */
GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                     (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target->str();
  }
}

/**
 * @brief Partition a block's warps between the M and N dimensions according
 * to this warp policy.
 *
 * Returns {m_warp, n_warp} with m_warp * n_warp == num_warps and caches the
 * result in the policy node; for WGMMA, m_warp is always a whole number of
 * warp-groups (a multiple of 4).
 */
std::pair<int, int>
GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
                                         Target target, bool use_wgmma) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp

  ICHECK(M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << M;
  ICHECK(N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << N;

  if (use_wgmma) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup;             // Initially, one warp-group on the M dimension
    n_warp = num_warps / m_warp; // All remaining warps on the N dimension

    if (this->isFullRow()) {
      // Put as many warp-groups as possible on the M dimension
      // (decreasing multiples of 4, ensuring M stays divisible)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->isFullCol()) {
      // Try to put all warps on the N dimension; if N is not evenly
      // divisible, move the excess groups to M
      int cand_n = n_warp;                 // Initially assume all on N
      if (N % (cand_n * kNPerWarp) != 0) { // N-direction division fails
        int max_n = N / kNPerWarp;
        // Search downward from the largest feasible n_warp, keeping
        // num_warps / n_warp a multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->isSquare()) {
      // Exhaustive search, but m_warp must be a multiple of 4
      int max_m = M / kMPerWarp;
      int max_n = N / kNPerWarp;

      float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";

    // Store the computed values in the policy node
    this->m_warp = m_warp;
    this->n_warp = n_warp;

    return {m_warp, n_warp};
  }

  if (this->isFullRow()) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp * 16, move the remaining
    // warps to N
    if (M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = M / kMPerWarp;
      m_warp = max_m_warps;
      // Use the remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->isFullCol()) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp * 8, move the remaining
    // warps to M
    if (N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = N / kNPerWarp;
      n_warp = max_n_warps;
      // Use the remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->isSquare()) {
    // Maximum number of warps usable on M (each warp covers at least 16
    // rows of M)
    int max_m_warps = M / kMPerWarp;

    // Ideal ratio of M-warps to N-warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (N > 0) {
      ideal_ratio = static_cast<float>(M) / N;
    }

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();
    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
      // Each warp must cover at least one tile in both dimensions
      if (m_per_warp < 1 || n_per_warp < 1)
        continue;
      // m * n must exactly account for all warps
      if (m * n != num_warps)
        continue;

      // Calculate how balanced this partition is
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }

  // Store the computed values in the policy node
  this->m_warp = m_warp;
  this->n_warp = n_warp;

  return {m_warp, n_warp};
}
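
// Worked example (non-WGMMA, square policy, values illustrative): M = 128,
// N = 128 with block_size = 256 on a CUDA target gives num_warps = 8. The
// search over partitions with m * n == 8 scores each candidate by how close
// (M / (m * 16)) / (N / (n * 8)) comes to the ideal ratio M / N = 1, and
// settles on m_warp = 2, n_warp = 4 (4 tile-rows and 4 tile-columns per warp).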

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 *     and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
bool GemmNode::CheckWGMMA() const {
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }

  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}
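
// For instance (illustrative): A/B in fp16 accumulating into fp32 C passes
// whenever K % 16 == 0, with either operand orientation; an fp8 (e4m3) GEMM
// additionally requires A non-transposed, B transposed, and K % 32 == 0.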

/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it matches the pattern
 * "sm_<num>", returns <num> as an int. If the attribute is present but does not
 * match that pattern, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked via
 * ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
  } else {
    arch_int = 0;
  }
  return arch_int;
}

/**
 * @brief Lower the GEMM operator to a TL TIR call expression.
 *
 * Constructs a tl::gemm call string parameterized by M, N, K, warp partition,
 * transpose flags, accumulation clearing, target-specific stride/offset/kPack
 * and optional wg_wait value, then returns an Evaluate(call) node
 * invoking tl::tl_gemm with the composed string and the A/B/C buffer handles.
 *
 * @param T Contains lowering context including thread bounds and target.
 * @param analyzer Optional arithmetic analyzer used by lowering (may be
 * nullptr).
 * @return Stmt A TIR statement representing the evaluated TL GEMM call.
 */
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = policy->ComputeWarpPartition(
      M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);

  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << stride_A << ", " << stride_B;
    ss << ", " << offset_A << ", " << offset_B;
  }
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}
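
// An illustrative composed call (values hypothetical): a 128x128x32 GEMM on
// a Hopper target with both operands in shared memory, a 4x2 warp partition,
// trans_B = 1, and WGMMA enabled would produce
//   tl::gemm_ss<128, 128, 32, 4, 2, 0, 1, 0, 32, 32, 0, 0, true>
// as the template string handed to tl::tl_gemm along with the A/B/C access
// pointers.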

/**
 * @brief Infer and bind target-specific memory/layout mappings for A, B, and
 * C.
 *
 * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
 * operator according to the target architecture, thread bounds, warp
 * partitioning, data types, and transpose flags, then binds fragment layouts
 * to the thread range when required.
 *
 * Preconditions:
 * - C.scope() == "local.fragment"
 *
 * Side effects:
 * - Marks layout inference as completed (sets completed_ = true).
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
 * @param T Input layout-inference context (provides thread bounds and target).
 * @return LayoutMap mapping A, B, and C to their inferred layouts.
 */
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = policy->ComputeWarpPartition(
      M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);

  if (TargetIsVolta(T.target)) {
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target)) {
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    auto fragment =
        gemm_inst == GemmInst::kWGMMA
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), trans_A ? 1 : 2)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), trans_A ? 1 : 2);
      results.Set(A, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B ? 2 : 1)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B ? 2 : 1);
      results.Set(B, ABLayout);
    } else {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    }
  } else if (TargetIsCDNA(T.target)) {
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}
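
// Illustrative outcome (assumptions for the example only): on an Ampere
// target with A and B in shared memory and C in "local.fragment", `results`
// maps A and B to swizzled shared-memory layouts from makeGemmABLayout and
// C to a register fragment bound to the block's thread range.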

TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tl.GemmWarpPolicy")
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");

TVM_FFI_STATIC_INIT_BLOCK({
  GemmNode::RegisterReflection();
  GemmWarpPolicyNode::RegisterReflection();
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
                        [](GemmWarpPolicy policy, int M, int N, int block_size,
                           Target target, bool is_wgmma) {
                          policy->ComputeWarpPartition(M, N, block_size, target,
                                                       is_wgmma);
                          return;
                        });
});
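
// The reflection entry above makes the partition helper reachable through
// the TVM FFI; a hypothetical Python-side probe could look like:
//   f = tvm.get_global_func("tl.GemmWarpPolicyComputeWarpPartition")
//   f(policy, 128, 128, 256, tvm.target.Target("cuda"), False)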

} // namespace tl
} // namespace tvm