/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

#include "builtin.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Compute the prime factorization of an integer.
 *
 * Returns the prime factors of x in non-decreasing order by repeatedly dividing
 * out the smallest possible factor.
 *
 * @param x Integer to factorize. If x <= 1, an empty vector is returned.
 * @return std::vector<int> Prime factors of x (with multiplicity), in
 * non-decreasing order.
 */
static std::vector<int> toPrimeFactors(int x) {
  int i = 2;
  std::vector<int> result;
  while (x > 1) {
    if (x % i == 0) {
      x /= i;
      result.push_back(i);
    } else {
      i++;
    }
  }
  return result;
}
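// Worked example (illustrative): toPrimeFactors(120) divides out factors as
// 120 -> 60 -> 30 -> 15 -> 5 -> 1 and returns {2, 2, 2, 3, 5};
// toPrimeFactors(7) returns {7}; toPrimeFactors(1) returns {}.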

/**
 * @brief Construct a Gemm operator from serialized TL arguments and a buffer
 * map.
 *
 * This constructor deserializes operator parameters from `args` and resolves
 * buffer references via `vmap`, populating an internal GemmNode with:
 * - device pointers for A, B, C and their corresponding Buffer objects,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
 *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      (optional) kPack (Int), (optional) wg_wait (Int)]
 * @param vmap Mapping from access pointer vars to Buffer objects used to
 *   resolve the Buffer corresponding to each pointer argument.
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<GemmNode> node = make_object<GemmNode>();

  node->Aptr = args[0];
  node->Bptr = args[1];
  node->Cptr = args[2];
  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
  node->trans_A = args[3].as<Bool>().value();
  node->trans_B = args[4].as<Bool>().value();
  node->M = args[5].as<IntImm>().value()->value;
  node->N = args[6].as<IntImm>().value()->value;
  node->K = args[7].as<IntImm>().value()->value;
  node->policy =
      static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  node->clear_accum = args[9].as<Bool>().value();
  node->stride_A = args[10].as<IntImm>().value()->value;
  node->stride_B = args[11].as<IntImm>().value()->value;
  node->offset_A = args[12].as<IntImm>().value()->value;
  node->offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack = args[14].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wg_wait = args[15].as<IntImm>().value()->value;
  }
  data_ = std::move(node);
}
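// Illustrative argument layout for the constructor above (hypothetical
// values; the TL frontend normally assembles this array):
//   {Aptr, Bptr, Cptr,
//    /*trans_A=*/false, /*trans_B=*/true,
//    /*M=*/128, /*N=*/128, /*K=*/64,
//    /*policy=*/<GemmWarpPolicy as Int>, /*clear_accum=*/false,
//    /*stride_A=*/64, /*stride_B=*/64, /*offset_A=*/0, /*offset_B=*/0,
//    /*kPack=*/1, /*wg_wait=*/0}  // the last two entries are optional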

/**
 * @brief Create a copy of this GemmNode as a TileOperator.
 *
 * Constructs a new GemmNode by copying the current node state and returns it
 * wrapped in a Gemm TileOperator.
 *
 * @return TileOperator A Gemm operator that owns a copy of this node.
 */
TileOperator GemmNode::Clone() const {
  auto op = make_object<GemmNode>(*this);
  return Gemm(op);
}

/**
 * @brief Selects the GEMM implementation variant for a given block size and
 * target.
 *
 * Determines which low-level GEMM instruction to use:
 * - Returns kWGMMA when running on Hopper-class targets and the operator meets
 *   WGMMA constraints (M >= 64, number of warps is a multiple of 4, and
 *   CheckWGMMA() returns true).
 * - Returns kMFMA for CDNA targets.
 * - Returns kMMA for CUDA targets.
 *
 * @param block_size Number of threads in the CUDA/ROCm thread block used for
 * the GEMM.
 * @param target Target backend describing the hardware (used to detect
 * architecture).
 * @return GemmInst The chosen GEMM implementation enum value.
 *
 * @throws fatal error (ICHECK) If the target is not recognized/supported, this
 * function triggers a runtime check failure.
 */
GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                     (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target->str();
  }
}
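// Example (illustrative): on a Hopper target with block_size = 256,
// num_warps = 256 / 32 = 8, so kWGMMA is chosen when M >= 64 and CheckWGMMA()
// passes; otherwise the same CUDA target falls back to kMMA.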

/**
 * @brief Compute how warps are partitioned between the M and N GEMM dimensions.
 *
 * Determines the number of warps assigned to the M (rows) and N (columns)
 * dimensions for a block given the selected GEMM implementation and target.
 * The function enforces constraints required by the implementations (e.g.,
 * per-warp tile sizes) and adapts the partition according to the configured
 * GemmWarpPolicy (FullRow, FullCol, Square).
 *
 * @param block_size Total number of threads in the block (used to derive
 * num_warps).
 * @param gemm_inst The chosen GEMM implementation (e.g., kWGMMA, kMFMA, kMMA).
 * @param target Target device information (used for warp size and
 * target-specific rules).
 * @return std::pair<int, int> {m_warp, n_warp} where m_warp * n_warp ==
 * num_warps.
 *
 * Constraints and behavior:
 * - Each warp is assumed to cover 16 rows (M) and 8 columns (N). The function
 *   checks that M % 16 == 0 and N % 8 == 0.
 * - num_warps is computed as block_size / warp_size(target).
 * - For WGMMA (kWGMMA):
 *   - num_warps must be a multiple of 4 (warp-groups of 4).
 *   - m_warp is always a multiple of 4.
 *   - The warp partition respects the GemmWarpPolicy:
 *     - FullRow: maximize warps on M (in multiples of 4) while keeping
 *       divisibility.
 *     - FullCol: maximize warps on N, but if N is not evenly divisible, move
 *       whole warp-groups to M to achieve feasibility.
 *     - Square: choose a multiple-of-4 m_warp that best balances per-warp work
 *       between M and N.
 * - For non-WGMMA implementations:
 *   - FullRow: favor allocating warps to M first; if M cannot use all warps,
 *     remaining warps are placed on N.
 *   - FullCol: favor allocating warps to N first; if N cannot use all warps,
 *     remaining warps are placed on M.
 *   - Square: search for the m/n split that best balances per-warp work given
 *     integer warp counts and the per-warp tile sizes.
 *
 * Error handling:
 * - The function performs internal checks (ICHECK) and will fail if required
 *   divisibility or policy conditions are not met (e.g., M/N tile divisibility,
 *   invalid policy, or WGMMA-specific warp-group requirements).
 */
std::pair<int, int> GemmNode::ComputeWarpPartition(int block_size,
                                                   GemmInst gemm_inst,
                                                   Target target) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp

  ICHECK(this->M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << this->M;
  ICHECK(this->N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << this->N;
  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->policy == GemmWarpPolicy::kFullRow) {
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (this->M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                       // Initially assume all on N
      if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = this->N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->policy == GemmWarpPolicy::kSquare) {
      // Exhaustive search, but m must be multiple of 4
      int max_m = this->M / kMPerWarp;
      int max_n = this->N / kNPerWarp;

      float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";
    return {m_warp, n_warp};
  }

  if (this->policy == GemmWarpPolicy::kFullRow) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (this->M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = this->M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (this->N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = this->N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    // First calculate the maximum possible warps for each dimension
    int max_m_warps =
        this->M / kMPerWarp; // Each warp needs at least 16 elements in M
    int max_n_warps =
        this->N / kNPerWarp; // Each warp needs at least 8 elements in N

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (this->N > 0) {
      ideal_ratio = static_cast<float>(this->M) / this->N;
    }

    // Start with a balanced initial guess
    m_warp = 1;
    n_warp = 1;

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();

    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}
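// Worked example (illustrative) for the non-WGMMA kSquare path: with
// M = N = 128 and block_size = 256 (num_warps = 8), the loop compares
// m_per_warp / n_per_warp against ideal_ratio = 1.0; the split m = 2, n = 4
// gives 128 / (2 * 16) = 4 and 128 / (4 * 8) = 4 (balance 0), so the function
// returns {2, 4}.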

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and
 *     K % 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 *     and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
bool GemmNode::CheckWGMMA() const {
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }

  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}
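// Example (illustrative): fp16 x fp16 accumulating into fp32 with K = 32
// passes for either transpose orientation (K % 16 == 0), whereas an
// e4m3 x e4m3 pair additionally requires !trans_A && trans_B and K % 32 == 0.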

/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it matches the pattern
 * "sm_<num>", returns <num> as an int. If the attribute is present but does not
 * match that pattern, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked via
 * ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  const char *arch_str = s.value().c_str();
  if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
    arch_int = atoi(&arch_str[3]);
  } else {
    arch_int = 0;
  }
  return arch_int;
}
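// Example (illustrative): "sm_80" parses to 80 and "sm_90a" to 90 (atoi stops
// at the first non-digit); a non-NVIDIA arch string such as "gfx942" yields 0.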

/**
 * @brief Lower the GEMM operator to a TL TIR call expression.
 *
 * Constructs a tl::gemm call string parameterized by M, N, K, warp partition,
 * transpose flags, accumulation clearing, target-specific stride/offset/kPack
 * and optional workgroup wait value, then returns an Evaluate(call) node
 * invoking tl::tl_gemm with the composed string and the A/B/C buffer handles.
 *
 * @param T Contains lowering context including thread bounds and target.
 * @param analyzer Optional arithmetic analyzer used by lowering (may be
 * nullptr).
 * @return Stmt A TIR statement representing the evaluated TL GEMM call.
 */
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << stride_A << ", " << stride_B;
    ss << ", " << offset_A << ", " << offset_B;
  }
  if (TargetIsCDNA(T.target)) {
    // For CDNA gemm, we need to specify kPack.
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}
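// Sketch of an emitted call (hypothetical operands on an sm_90 target with
// 256 threads, warp partition {4, 2}, trans_A = false, trans_B = true):
//   tl::tl_gemm("tl::gemm_ss<128, 128, 64, 4, 2, 0, 1, 0, 64, 64, 0, 0, true>",
//               Aptr, Bptr, Cptr)
// where the trailing "true" selects the WGMMA path and a zero wg_wait is
// omitted.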

/**
 * @brief Infer and bind target-specific memory/layout mappings for A, B, and C.
 *
 * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
 * operator according to the target architecture, thread bounds, warp
 * partitioning, data types, and transpose flags, then binds fragment layouts
 * to the thread range when required.
 *
 * Preconditions:
 * - C.scope() == "local.fragment"
 *
 * Side effects:
 * - Marks layout inference as completed (sets completed_ = true).
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 *   incompatible shape constraints.
 *
 * @param T Input layout-inference context (provides thread bounds and target).
 * @return LayoutMap mapping A, B, and C to their inferred layouts.
 */
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  if (TargetIsVolta(T.target)) {
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target)) {
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    auto fragment =
        gemm_inst == GemmInst::kWGMMA
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), trans_A ? 1 : 2)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), trans_A ? 1 : 2);
      results.Set(A, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B ? 2 : 1)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B ? 2 : 1);
      results.Set(B, ABLayout);
    } else {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    }
  } else if (TargetIsCDNA(T.target)) {
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm