/*!
 * \file tl/op/gemm.cc
 *
 * \brief Define the GEMM operator.
 */

#include "gemm.h"

#include "builtin.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

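// Decompose a positive integer into its prime factors in non-decreasing
// order, e.g. 12 -> {2, 2, 3}; returns an empty vector for x <= 1.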
static std::vector<int> toPrimeFactors(int x) {
  int i = 2;
  std::vector<int> result;
  while (x > 1) {
    if (x % i == 0) {
      x /= i;
      result.push_back(i);
    } else {
      i++;
    }
  }
  return result;
}

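/*!
 * \brief Construct a Gemm operator from its TL call arguments.
 *
 * The argument layout mirrors the field assignments below:
 *   args[0..2]   access pointers for A, B and C
 *   args[3..4]   trans_A, trans_B
 *   args[5..7]   M, N, K
 *   args[8]      warp partition policy
 *   args[9]      clear_accum
 *   args[10..13] stride_A, stride_B, offset_A, offset_B
 *   args[14]     kPack (optional, must be 1 or 2)
 *   args[15]     wg_wait (optional)
 */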
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  Aptr = args[0];
  Bptr = args[1];
  Cptr = args[2];
  A = vmap[GetVarFromAccessPtr(Aptr)];
  B = vmap[GetVarFromAccessPtr(Bptr)];
  C = vmap[GetVarFromAccessPtr(Cptr)];
  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();
  stride_A = args[10].as<IntImm>().value()->value;
  stride_B = args[11].as<IntImm>().value()->value;
  offset_A = args[12].as<IntImm>().value()->value;
  offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    kPack = args[14].as<IntImm>().value()->value;
    ICHECK(kPack == 1 || kPack == 2) << "kPack must be 1 or 2";
  }
  if (args.size() > 15) {
    wg_wait = args[15].as<IntImm>().value()->value;
  }
}

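/*!
 * \brief Select the MMA instruction family for this GEMM: WGMMA on Hopper
 *        when M >= 64, the block holds a whole number of warp-groups and
 *        CheckWGMMA() passes; otherwise MFMA on CDNA and MMA on other CUDA
 *        targets.
 */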
Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                     (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target->str();
  }
}

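/*!
 * \brief Partition the block's warps into an (m_warp, n_warp) grid.
 *
 * Each warp covers kMPerWarp (16) rows of M and kNPerWarp (8) columns of N.
 * kFullRow favours M, kFullCol favours N, and kSquare searches for the split
 * whose per-warp tile ratio best matches M/N. For WGMMA, m_warp is kept a
 * multiple of 4 so that whole warp-groups are assigned along M.
 */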
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
                                               GemmInst gemm_inst,
                                               Target target) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp

  ICHECK(this->M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << this->M;
  ICHECK(this->N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << this->N;
  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0)
        << "Warp-Group MMA requires the number of warps to be a multiple of 4 "
           "(128 threads per warp-group).";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup;             // Start with a single warp-group on M
    n_warp = num_warps / m_warp; // Remaining warps go to N

    if (this->policy == GemmWarpPolicy::kFullRow) {
      // Put as many warp-groups as possible on the M dimension: scan
      // decreasing multiples of 4 and take the first cand with
      // M % (cand * kMPerWarp) == 0.
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (this->M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                       // Initially assume all on N
      if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = this->N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->policy == GemmWarpPolicy::kSquare) {
      // Exhaustive search, but m must be multiple of 4
      int max_m = this->M / kMPerWarp;
      int max_n = this->N / kNPerWarp;

      float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";
    return {m_warp, n_warp};
  }

  if (this->policy == GemmWarpPolicy::kFullRow) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (this->M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = this->M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (this->N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = this->N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    // First calculate the maximum possible warps for each dimension
    int max_m_warps =
        this->M / kMPerWarp; // Each warp needs at least 16 elements in M
    int max_n_warps =
        this->N / kNPerWarp; // Each warp needs at least 8 elements in N

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (this->N > 0) {
      ideal_ratio = static_cast<float>(this->M) / this->N;
    }

    // Start with a balanced initial guess
    m_warp = 1;
    n_warp = 1;

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();

    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

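/*!
 * \brief Whether this GEMM can be lowered to Hopper WGMMA.
 *
 * Requires B to reside in shared memory and an (A, B, C) dtype combination
 * supported by WGMMA with the matching K alignment: K % 16 for 16-bit inputs,
 * K % 8 for fp32 and K % 32 for fp8/int8 inputs; the fp32, fp8 and int8 paths
 * additionally require non-transposed A and transposed B.
 */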
bool Gemm::CheckWGMMA() const {
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }

  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}

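// Parse the numeric part of a CUDA "arch" attribute such as "sm_90";
// returns 0 when the attribute does not follow the "sm_<N>" pattern.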
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  const char *arch_str = s.value().c_str();
  if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
    arch_int = atoi(&arch_str[3]);
  } else {
    arch_int = 0;
  }
  return arch_int;
}

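/*!
 * \brief Lower this GEMM to a call of the device-side template kernel.
 *
 * Emits a tl::tl_gemm intrinsic whose first argument is the template
 * instantiation string. As an illustrative sketch, the string has the form
 *
 *   tl::gemm_{ss|rs|sr}<M, N, K, warp_m, warp_n, trans_A, trans_B,
 *                       clear_accum, ...>
 *
 * where the suffix encodes whether A/B are read from shared memory (s) or a
 * register fragment (r), and the trailing parameters (strides/offsets on
 * sm_75+, kPack on CDNA, the WGMMA flag on Hopper, wg_wait) depend on the
 * target branches below.
 */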
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << stride_A << ", " << stride_B;
    ss << ", " << offset_A << ", " << offset_B;
  }
  if (TargetIsCDNA(T.target)) {
    // For CDNA gemm we additionally need to specify kPack.
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}

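/*!
 * \brief Infer memory/fragment layouts for A, B and C on the first call.
 *
 * C is always required to be a "local.fragment" accumulator. Per target
 * family (Volta, Ampere/Turing/SM120, Hopper, CDNA) this picks the matching
 * fragment layout for register operands and a target-specific shared-memory
 * layout for shared operands, each bound to the block's thread range.
 * Subsequent calls return an empty map once the layouts have been recorded.
 */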
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  if (TargetIsVolta(T.target)) {
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target)) {
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    auto fragment =
        gemm_inst == GemmInst::kWGMMA
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), trans_A ? 1 : 2)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), trans_A ? 1 : 2);
      results.Set(A, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B ? 2 : 1)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B ? 2 : 1);
      results.Set(B, ABLayout);
    } else {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    }
  } else if (TargetIsCDNA(T.target)) {
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

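// Register the TL gemm operator. The call is marked opaque so TIR passes
// treat it conservatively and do not assume it is side-effect free.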
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm