/*!
 * \file tl/op/gemm.cc
 *
 * Define the GEMM operator.
 */

#include "gemm.h"

#include "builtin.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

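// Decompose a positive integer into its prime factors in nondecreasing order,
// e.g. toPrimeFactors(12) == {2, 2, 3}.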
static std::vector<int> toPrimeFactors(int x) {
  int i = 2;
  std::vector<int> result;
  while (x > 1) {
    if (x % i == 0) {
      x /= i;
      result.push_back(i);
    } else {
      i++;
    }
  }
  return result;
}

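// Positional argument layout consumed below:
//   args[0..2]    access pointers for A, B, C
//   args[3..4]    trans_A, trans_B
//   args[5..7]    M, N, K
//   args[8]       GemmWarpPolicy
//   args[9]       clear_accum
//   args[10..13]  stride_A, stride_B, offset_A, offset_B
//   args[14]      kPack, optional (must be 1 or 2)
//   args[15]      wg_wait, optional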
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<GemmNode> node = make_object<GemmNode>();

  node->Aptr = args[0];
  node->Bptr = args[1];
  node->Cptr = args[2];
  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
  node->trans_A = args[3].as<Bool>().value();
  node->trans_B = args[4].as<Bool>().value();
  node->M = args[5].as<IntImm>().value()->value;
  node->N = args[6].as<IntImm>().value()->value;
  node->K = args[7].as<IntImm>().value()->value;
  node->policy =
      static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  node->clear_accum = args[9].as<Bool>().value();
  node->stride_A = args[10].as<IntImm>().value()->value;
  node->stride_B = args[11].as<IntImm>().value()->value;
  node->offset_A = args[12].as<IntImm>().value()->value;
  node->offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack = args[14].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wg_wait = args[15].as<IntImm>().value()->value;
  }
  data_ = std::move(node);
}

TileOperator GemmNode::Clone() const {
  auto op = make_object<GemmNode>(*this);
  return Gemm(op);
}

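// Select the MMA flavor: Hopper targets use warp-group MMA (WGMMA) when the
// shape and warp count allow it (M >= 64, a multiple of four warps, and
// CheckWGMMA() passes); CDNA targets use MFMA; other CUDA targets fall back
// to per-warp MMA.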
GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                     (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target->str();
  }
}

/**
 * @brief Compute how warps are partitioned between the M and N GEMM dimensions.
 *
 * Determines the number of warps assigned to the M (rows) and N (columns)
 * dimensions for a block given the selected GEMM implementation and target.
 * The function enforces constraints required by the implementations (e.g.,
 * per-warp tile sizes) and adapts the partition according to the configured
 * GemmWarpPolicy (FullRow, FullCol, Square).
 *
 * @param block_size Total number of threads in the block (used to derive
 * num_warps).
 * @param gemm_inst The chosen GEMM implementation (e.g., kWGMMA, kMFMA, kMMA).
 * @param target Target device information (used for warp size and
 * target-specific rules).
 * @return std::pair<int, int> {m_warp, n_warp} where m_warp * n_warp ==
 * num_warps.
 *
 * Constraints and behavior:
 * - Each warp is assumed to cover 16 rows (M) and 8 columns (N). The function
 *   checks that M % 16 == 0 and N % 8 == 0.
 * - num_warps is computed as block_size / warp_size(target).
 * - For WGMMA (kWGMMA):
 *   - num_warps must be a multiple of 4 (warp-groups of 4).
 *   - m_warp is always a multiple of 4.
 *   - The warp partition respects the GemmWarpPolicy:
 *     - FullRow: maximize warps on M (in multiples of 4) while keeping
 *       divisibility.
 *     - FullCol: maximize warps on N, but if N is not evenly divisible, move
 *       whole warp-groups to M to achieve feasibility.
 *     - Square: choose a multiple-of-4 m_warp that best balances per-warp work
 *       between M and N.
 * - For non-WGMMA implementations:
 *   - FullRow: favor allocating warps to M first; if M cannot use all warps,
 *     remaining warps are placed on N.
 *   - FullCol: favor allocating warps to N first; if N cannot use all warps,
 *     remaining warps are placed on M.
 *   - Square: search for the m/n split that best balances per-warp work given
 *     integer warp counts and the per-warp tile sizes.
 *
 * Error handling:
 * - The function performs internal checks (ICHECK) and will fail if required
 *   divisibility or policy conditions are not met (e.g., M/N tile divisibility,
 *   invalid policy, or WGMMA-specific warp-group requirements).
 */
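// Worked example (Square policy, non-WGMMA): M = 128, N = 128 with
// block_size = 256 on CUDA (warp size 32) gives num_warps = 8. The search
// below picks m_warp = 2, n_warp = 4, since each warp then covers
// 128 / (2 * 16) = 4 M-tiles and 128 / (4 * 8) = 4 N-tiles, exactly matching
// the ideal M/N ratio of 1.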
std::pair<int, int> GemmNode::ComputeWarpPartition(int block_size,
                                                   GemmInst gemm_inst,
                                                   Target target) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp

  ICHECK(this->M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << this->M;
  ICHECK(this->N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << this->N;
  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->policy == GemmWarpPolicy::kFullRow) {
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (this->M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                       // Initially assume all on N
      if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = this->N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->policy == GemmWarpPolicy::kSquare) {
      // Exhaustive search, but m must be multiple of 4
      int max_m = this->M / kMPerWarp;
      int max_n = this->N / kNPerWarp;

      float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";
    return {m_warp, n_warp};
  }

  if (this->policy == GemmWarpPolicy::kFullRow) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (this->M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = this->M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (this->N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = this->N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    // First calculate the maximum possible warps for each dimension
    int max_m_warps =
        this->M / kMPerWarp; // Each warp needs at least 16 elements in M
    int max_n_warps =
        this->N / kNPerWarp; // Each warp needs at least 8 elements in N

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (this->N > 0) {
      ideal_ratio = static_cast<float>(this->M) / this->N;
    }

    // Start with a balanced initial guess
    m_warp = 1;
    n_warp = 1;

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();

    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and
 *     K % 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 *     and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
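// Example: assuming B lives in shared memory, C=float32 with A=B=bfloat16 and
// K=64 passes (K % 16 == 0, no orientation requirement), while A=B=float32
// with trans_A=true fails, because fp32 WGMMA needs A non-transposed and B
// transposed.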
bool GemmNode::CheckWGMMA() const {
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }

  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}

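// Parse the numeric SM version from the target's "arch" attribute,
// e.g. "sm_90" -> 90; arch strings that do not start with "sm_" yield 0.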
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  const char *arch_str = s.value().c_str();
  if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
    arch_int = atoi(&arch_str[3]);
  } else {
    arch_int = 0;
  }
  return arch_int;
}

Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << stride_A << ", " << stride_B;
    ss << ", " << offset_A << ", " << offset_B;
  }
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";
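  // The assembled string names a template specialization of the runtime GEMM
  // kernel. With illustrative values on an Ampere target it would look like
  //   tl::gemm_ss<128, 128, 32, 2, 2, 0, 1, 0, 128, 32, 0, 0>
  // (M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, strides/offsets).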

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}

/**
 * @brief Infer memory/layout mappings for A, B, and C buffers for this GEMM op.
 *
 * Generates and returns a LayoutMap that binds buffer A, B, and C to
 * target- and architecture-specific fragment or shared-memory layouts based
 * on the current target, thread bounds, warp partitioning, data types, and
 * transpose flags. This performs target dispatch (Volta, Ampere/Turing/SM120,
 * Hopper, CDNA), selects the appropriate fragment or shared layout creators,
 * and binds fragment layouts to the thread range when buffers are local
 * fragments.
 *
 * Preconditions:
 * - C.scope() must be "local.fragment".
 *
 * Postconditions / side effects:
 * - Marks the operator's layout inference as completed (sets
 *   completed_ = true).
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
 * @param T Layout inference inputs (thread bounds and target).
 * @param level Inference level (unused for side effects but retained for API).
 * @return LayoutMap mapping each of A, B, and C to their inferred layouts.
 */
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  if (TargetIsVolta(T.target)) {
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target)) {
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    auto fragment =
        gemm_inst == GemmInst::kWGMMA
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), trans_A ? 1 : 2)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), trans_A ? 1 : 2);
      results.Set(A, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B ? 2 : 1)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B ? 2 : 1);
      results.Set(B, ABLayout);
    } else {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    }
  } else if (TargetIsCDNA(T.target)) {
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm