gemm.cc 22.6 KB
Newer Older
1
2
3
4
5
6
7
8
/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

// Decompose a positive integer into its prime factors, ascending and with
// multiplicity (e.g. 12 -> {2, 2, 3}). Inputs <= 1 yield an empty vector.
static std::vector<int> toPrimeFactors(int x) {
  std::vector<int> factors;
  int divisor = 2;
  while (x > 1) {
    if (x % divisor != 0) {
      ++divisor;
      continue;
    }
    x /= divisor;
    factors.push_back(divisor);
  }
  return factors;
}

/**
 * @brief Construct a Gemm operator from TL call arguments.
 *
 * Argument layout (positional):
 *   0-2   access pointers for A, B, C (buffers resolved through vmap)
 *   3-4   transpose flags for A and B
 *   5-7   problem shape M, N, K
 *   8     warp-partition policy
 *   9     whether the accumulator is cleared before the GEMM
 *   10-13 strides and offsets for A and B
 *   14    (optional) kPack for CDNA MFMA; must be 1 or 2
 *   15    (optional) warp-group wait count (wg_wait)
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  Aptr = args[0];
  Bptr = args[1];
  Cptr = args[2];
  A = vmap[GetVarFromAccessPtr(Aptr)];
  B = vmap[GetVarFromAccessPtr(Bptr)];
  C = vmap[GetVarFromAccessPtr(Cptr)];
  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();
  stride_A = args[10].as<IntImm>().value()->value;
  stride_B = args[11].as<IntImm>().value()->value;
  offset_A = args[12].as<IntImm>().value()->value;
  offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    kPack = args[14].as<IntImm>().value()->value;
    // CDNA MFMA only supports k-packing factors of 1 or 2.
    ICHECK(kPack == 1 || kPack == 2) << "kPack must be 1 or 2";
  }
  if (args.size() > 15) {
    wg_wait = args[15].as<IntImm>().value()->value;
  }
}

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/**
 * @brief Select the GEMM instruction flavor for a given block size and target.
 *
 * WGMMA is chosen on Hopper when M >= 64, the block holds whole warp-groups
 * (num_warps % 4 == 0), and CheckWGMMA() accepts the dtype/layout combo.
 * Otherwise CDNA targets use MFMA and CUDA targets use MMA.
 *
 * @param block_size Number of threads in the thread block.
 * @param target Compilation target.
 * @return The selected GemmInst; aborts via ICHECK on unsupported targets.
 */
Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                     (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  }
  ICHECK(0) << "Unsupported target for gemm: " << target->str();
  return GemmInst::kMMA; // unreachable; silences missing-return warnings
}

81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
/**
 * @brief Compute how warps are partitioned between the M and N GEMM dimensions.
 *
 * Determines the number of warps assigned to the M (rows) and N (columns)
 * dimensions for a block given the selected GEMM implementation and target.
 * The function enforces constraints required by the implementations (e.g.,
 * per-warp tile sizes) and adapts the partition according to the configured
 * GemmWarpPolicy (FullRow, FullCol, Square).
 *
 * @param block_size Total number of threads in the block (used to derive num_warps).
 * @param gemm_inst The chosen GEMM implementation (e.g., kWGMMA, kMFMA, kMMA).
 * @param target Target device information (used for warp size and target-specific rules).
 * @return std::pair<int, int> {m_warp, n_warp} where m_warp * n_warp == num_warps.
 *
 * Constraints and behavior:
 * - Each warp is assumed to cover 16 rows (M) and 8 columns (N). The function
 *   checks that M % 16 == 0 and N % 8 == 0.
 * - num_warps is computed as block_size / warp_size(target).
 * - For WGMMA (kWGMMA):
 *   - num_warps must be a multiple of 4 (warp-groups of 4).
 *   - m_warp is always a multiple of 4.
 *   - The warp partition respects the GemmWarpPolicy:
 *     - FullRow: maximize warps on M (in multiples of 4) while keeping divisibility.
 *     - FullCol: maximize warps on N, but if N is not evenly divisible, move
 *       whole warp-groups to M to achieve feasibility.
 *     - Square: choose a multiple-of-4 m_warp that best balances per-warp work
 *       between M and N.
 * - For non-WGMMA implementations:
 *   - FullRow: favor allocating warps to M first; if M cannot use all warps,
 *     remaining warps are placed on N.
 *   - FullCol: favor allocating warps to N first; if N cannot use all warps,
 *     remaining warps are placed on M.
 *   - Square: search for the m/n split that best balances per-warp work given
 *     integer warp counts and the per-warp tile sizes.
 *
 * Error handling:
 * - The function performs internal checks (ICHECK) and will fail if required
 *   divisibility or policy conditions are not met (e.g., M/N tile divisibility,
 *   invalid policy, or WGMMA-specific warp-group requirements).
 */
121
122
123
124
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
                                               GemmInst gemm_inst,
                                               Target target) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp

  ICHECK(this->M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << this->M;
  ICHECK(this->N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << this->N;

  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup;             // Initially, one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->policy == GemmWarpPolicy::kFullRow) {
      // Put as many warp-groups as possible on the M dimension
      // (decreasing multiples of 4, ensuring M divisibility).
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (this->M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      // Try to use warps on the N dimension; if N is not divisible, move
      // whole warp-groups to M until a feasible split is found.
      int cand_n = n_warp;                       // Initially assume all on N
      if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = this->N / kNPerWarp;
        // Find a feasible n_warp from the max possible downwards, keeping
        // num_warps / n_warp a multiple of the warp-group size.
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->policy == GemmWarpPolicy::kSquare) {
      // Exhaustive search for the most balanced split; m must stay a
      // multiple of the warp-group size.
      int max_m = this->M / kMPerWarp;
      int max_n = this->N / kNPerWarp;

      float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";
    return {m_warp, n_warp};
  }

  if (this->policy == GemmWarpPolicy::kFullRow) {
    // Try to partition M first.
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, split the remaining warps
    // to N.
    if (this->M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M.
      int max_m_warps = this->M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N.
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    // Try to partition N first.
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, split the remaining warps
    // to M.
    if (this->N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N.
      int max_n_warps = this->N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M.
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    // Maximum warps usable on M given the per-warp tile height.
    int max_m_warps =
        this->M / kMPerWarp; // Each warp needs at least 16 elements in M

    // The ideal M/N warp ratio follows the matrix aspect ratio.
    float ideal_ratio = 1.0f;
    if (this->N > 0) {
      ideal_ratio = static_cast<float>(this->M) / this->N;
    }

    // Search all feasible m splits for the most balanced per-warp workload.
    // NOTE(review): n = num_warps / m is not capped by N / kNPerWarp here,
    // so very wide blocks on a narrow N can assign n warps beyond what N
    // covers — confirm this is intentional.
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();

    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Score how close this split's per-warp workload ratio is to ideal.
      float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
311
bool Gemm::CheckWGMMA() const {
312
313
314
315
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }

316
317
318
  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
319
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
320
      return (!trans_A) && trans_B && K % 32 == 0;
321
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
322
      return (!trans_A) && trans_B && K % 32 == 0;
323
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
324
      return (!trans_A) && trans_B && K % 32 == 0;
325
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
326
327
328
329
330
331
332
333
334
335
336
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
337
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
338
      return (!trans_A) && trans_B && K % 32 == 0;
339
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
340
      return (!trans_A) && trans_B && K % 32 == 0;
341
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
342
      return (!trans_A) && trans_B && K % 32 == 0;
343
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}

363
364
365
366
367
368
369
370
371
372
373
374
375
// Parse the numeric CUDA architecture out of the target's "arch" attribute
// (e.g. "sm_90" -> 90). Returns 0 when the attribute is not "sm_<num>".
static int GetArchInt(Target target) {
  auto arch_attr = target->GetAttr<String>("arch");
  ICHECK(arch_attr.defined());
  std::string arch = arch_attr.value();
  if (arch.size() > 3 && arch.compare(0, 3, "sm_") == 0) {
    return atoi(arch.c_str() + 3);
  }
  return 0;
}

376
/**
 * @brief Lower the GEMM op to a call into the tl device GEMM template.
 *
 * Builds the template instantiation string "tl::gemm_{ss,rs,sr}<...>" from
 * the problem shape, warp partition, transpose/accumulate flags, and
 * target-specific extras (sm75+ strides/offsets, CDNA kPack, Hopper WGMMA
 * flag, optional wg_wait), then emits a tl_gemm call on the three pointers.
 *
 * @param T Lowering context (thread bounds and target).
 * @param analyzer Arithmetic analyzer (unused here, part of the Lower API).
 * @return An Evaluate statement wrapping the generated tl_gemm call.
 */
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  std::stringstream ss;
  // Suffix encodes operand residency: ss = both shared, rs = A in registers,
  // sr = B in registers. A and B cannot both be register fragments.
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  // sm75+ CUDA kernels take explicit leading-dimension strides and offsets.
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << stride_A << ", " << stride_B;
    ss << ", " << offset_A << ", " << offset_B;
  }
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}

413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
/**
 * @brief Infer memory/layout mappings for A, B, and C buffers for this GEMM op.
 *
 * Generates and returns a LayoutMap that binds buffer A, B, and C to
 * target- and architecture-specific fragment or shared-memory layouts based
 * on the current target, thread bounds, warp partitioning, data types, and
 * transpose flags. This performs target dispatch (Volta, Ampere/Turing/SM120,
 * Hopper, CDNA), selects the appropriate fragment or shared layout creators,
 * and binds fragment layouts to the thread range when buffers are local
 * fragments.
 *
 * Preconditions:
 * - C.scope() must be "local.fragment".
 *
 * Postconditions / side effects:
 * - Marks the operator's layout inference as completed (sets completed_ = true).
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
 * @param T Layout inference inputs (thread bounds and target).
 * @param level Inference level (unused for side effects but retained for API).
 * @return LayoutMap mapping each of A, B, and C to their inferred layouts.
 */
436
437
438
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  if (TargetIsVolta(T.target)) {
    // Volta: dedicated fragment/shared layout creators; B must be shared.
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target)) {
    // Ampere/Turing/SM120: standard MMA fragment and AB shared layouts.
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    // Hopper: WGMMA uses its own C fragment and AB shared layouts with a
    // warp-partition-dependent continuity; otherwise fall back to MMA forms.
    auto fragment =
        gemm_inst == GemmInst::kWGMMA
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), trans_A ? 1 : 2)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), trans_A ? 1 : 2);
      results.Set(A, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B ? 2 : 1)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B ? 2 : 1);
      results.Set(B, ABLayout);
    } else {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    }
  } else if (TargetIsCDNA(T.target)) {
    // CDNA: MFMA fragment for C; shared AB layouts take the kPack factor.
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register the gemm TL op; the call is opaque to TIR effect analysis.
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
588

589
} // namespace tl
590
} // namespace tvm