/*!
 * \file tl/op/gemm.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm.h"

#include "builtin.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Compute the prime factorization of an integer.
 *
 * Returns the prime factors of x in non-decreasing order by repeatedly dividing
 * out the smallest possible factor.
 *
 * @param x Integer to factorize. If x <= 1, an empty vector is returned.
 * @return std::vector<int> Prime factors of x (with multiplicity), in
 * non-decreasing order.
 */
static std::vector<int> toPrimeFactors(int x) {
  int i = 2;
  std::vector<int> result;
  while (x > 1) {
    if (x % i == 0) {
      x /= i;
      result.push_back(i);
    } else {
      i++;
    }
  }
  return result;
}
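
// For example (illustrative): toPrimeFactors(12) yields {2, 2, 3}, and
// toPrimeFactors(1) yields an empty vector since the loop never runs for
// x <= 1.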

/**
 * @brief Construct a Gemm operator from serialized TL arguments and a buffer
 * map.
 *
 * This constructor deserializes operator parameters from `args` and resolves
 * buffer references via `vmap`, populating an internal GemmNode with:
 * - device pointers for A, B, C and their corresponding Buffer objects,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
 *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      (optional) kPack (Int), (optional) wg_wait (Int)]
 * @param vmap Mapping from access pointer vars to Buffer objects used to
 *   resolve the Buffer corresponding to each pointer argument.
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<GemmNode> node = make_object<GemmNode>();

  node->Aptr = args[0];
  node->Bptr = args[1];
  node->Cptr = args[2];
  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
  node->trans_A = args[3].as<Bool>().value();
  node->trans_B = args[4].as<Bool>().value();
  node->M = args[5].as<IntImm>().value()->value;
  node->N = args[6].as<IntImm>().value()->value;
  node->K = args[7].as<IntImm>().value()->value;
  node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clear_accum = args[9].as<Bool>().value();
  node->stride_A = args[10].as<IntImm>().value()->value;
  node->stride_B = args[11].as<IntImm>().value()->value;
  node->offset_A = args[12].as<IntImm>().value()->value;
  node->offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack = args[14].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wg_wait = args[15].as<IntImm>().value()->value;
  }
  data_ = std::move(node);
}
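
// Illustrative deserialization (hypothetical values): a frontend gemm call
// lowered with M = 128, N = 128, K = 64, trans_A = false, trans_B = true
// arrives here as
//   args = [Aptr, Bptr, Cptr, false, true, 128, 128, 64,
//           policy, clear_accum, stride_A, stride_B, offset_A, offset_B]
// with kPack and wg_wait appended only when the frontend supplies them.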

/**
 * @brief Create a copy of this GemmNode as a TileOperator.
 *
 * Constructs a new GemmNode by copying the current node state and returns it
 * wrapped in a Gemm TileOperator.
 *
 * @return TileOperator A Gemm operator that owns a copy of this node.
 */
TileOperator GemmNode::Clone() const {
  auto op = make_object<GemmNode>(*this);
  return Gemm(op);
}

GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                     (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target->str();
  }
}
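
// Dispatch summary (follows from the checks above): on a Hopper target
// (sm_90) with block_size = 256 (8 warps), M >= 64, and CheckWGMMA()
// passing, this returns kWGMMA; the same configuration on an Ampere target
// (sm_80) falls through to kMMA.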

std::pair<int, int>
GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
                                         Target target, bool use_wgmma) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp

  ICHECK(M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << M;
  ICHECK(N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << N;

  if (use_wgmma) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->isFullRow()) {
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->isFullCol()) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                 // Initially assume all on N
      if (N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->isSquare()) {
      // Exhaustive search, but m must be a multiple of 4
      int max_m = M / kMPerWarp;
      int max_n = N / kNPerWarp;

      float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";

    // Store the computed values in the object's member variables
    this->m_warp = m_warp;
    this->n_warp = n_warp;

    return {m_warp, n_warp};
  }

  if (this->isFullRow()) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->isFullCol()) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->isSquare()) {
    // First calculate the maximum possible warps for each dimension
    int max_m_warps =
        M / kMPerWarp; // Each warp needs at least 16 elements in M

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (N > 0) {
      ideal_ratio = static_cast<float>(M) / N;
    }

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();

    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }

  // Store the computed values in the object's member variables
  this->m_warp = m_warp;
  this->n_warp = n_warp;

  return {m_warp, n_warp};
}
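
// Worked example (illustrative, non-WGMMA path): block_size = 256 on a
// 32-thread-warp target gives num_warps = 8. For M = N = 128 under the
// Square policy, ideal_ratio = 1.0 and the search selects m_warp = 2,
// n_warp = 4: each warp covers 128 / (2 * 16) = 4 units in M and
// 128 / (4 * 8) = 4 units in N, a perfectly balanced partition (score 0).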

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and
 *     K % 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 *     and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
bool GemmNode::CheckWGMMA() const {
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }

  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}
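
// Examples (derived from the table above): an fp16 x fp16 -> fp32 GEMM with
// K = 64 and B in shared memory is accepted (K % 16 == 0, no orientation
// constraint); an fp32 x fp32 -> fp32 GEMM additionally requires
// !trans_A && trans_B, so it is rejected when trans_B is false even if
// K % 8 == 0.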

/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it matches the pattern
 * "sm_<num>", returns <num> as an int. If the attribute is present but does not
 * match that pattern, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked via
 * ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
  } else {
    arch_int = 0;
  }
  return arch_int;
}
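
// For example: "sm_80" parses to 80 and "sm_90a" to 90 (std::stoi stops at
// the first non-digit character); any arch string not starting with "sm_"
// yields 0.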

/**
 * @brief Lower the GEMM operator to a TL TIR call expression.
 *
 * Constructs a tl::gemm call string parameterized by M, N, K, warp partition,
 * transpose flags, accumulation clearing, target-specific stride/offset/kPack
 * and optional workgroup wait value, then returns an Evaluate(call) node
 * invoking tl::tl_gemm with the composed string and the A/B/C buffer handles.
 *
 * @param T Contains lowering context including thread bounds and target.
 * @param analyzer Optional arithmetic analyzer used by lowering (may be
 * nullptr).
 * @return Stmt A TIR statement representing the evaluated TL GEMM call.
 */
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = policy->ComputeWarpPartition(
      M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);

  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << stride_A << ", " << stride_B;
    ss << ", " << offset_A << ", " << offset_B;
  }
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}
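
// Illustrative emitted call (hypothetical shapes; Ampere target, both
// operands in shared memory, 128 threads -> warp_m = warp_n = 2):
//   tl::gemm_ss<128, 128, 32, 2, 2, 0, 1, 0, 64, 64, 0, 0>
// i.e. M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, then
// stride_A, stride_B, offset_A, offset_B (emitted because sm_80 >= 75);
// wg_wait is omitted here because it is 0.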

/**
 * @brief Infer and bind target-specific memory/layout mappings for A, B, and
 * C.
 *
 * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
 * operator according to the target architecture, thread bounds, warp
 * partitioning, data types, and transpose flags, then binds fragment layouts
 * to the thread range when required.
 *
 * Preconditions:
 * - C.scope() == "local.fragment"
 *
 * Side effects:
 * - Marks layout inference as completed (sets completed_ = true).
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
 * @param T Input layout-inference context (provides thread bounds and target).
 * @param level Requested inference level.
 * @return LayoutMap mapping A, B, and C to their inferred layouts.
 */
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = policy->ComputeWarpPartition(
      M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);

  if (TargetIsVolta(T.target)) {
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target)) {
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    auto fragment =
        gemm_inst == GemmInst::kWGMMA
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), trans_A ? 1 : 2)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), trans_A ? 1 : 2);
      results.Set(A, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B ? 2 : 1)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B ? 2 : 1);
      results.Set(B, ABLayout);
    } else {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    }
  } else if (TargetIsCDNA(T.target)) {
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm