/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Decompose an integer into its prime factors.
 *
 * \param x The value to factorize.
 * \return The prime factors of x in non-decreasing order, with multiplicity.
 *         Empty when x <= 1.
 */
static std::vector<int> toPrimeFactors(int x) {
  std::vector<int> result;
  // Trial division only needs to go up to sqrt(x). `i <= x / i` is used
  // instead of `i * i <= x` so the bound cannot overflow for large x.
  for (int i = 2; x > 1 && i <= x / i; ++i) {
    while (x % i == 0) {
      result.push_back(i);
      x /= i;
    }
  }
  // Whatever remains above 1 is a single prime factor larger than sqrt(x).
  if (x > 1) {
    result.push_back(x);
  }
  return result;
}

/*!
 * \brief Build a Gemm operator from its call arguments.
 *
 * Argument layout as read here:
 *   args[0..2]  access pointers of A, B and C (their buffers are looked up
 *               in vmap through GetVarFromAccessPtr)
 *   args[3..4]  trans_A, trans_B (Bool)
 *   args[5..7]  M, N, K (IntImm)
 *   args[8]     warp partition policy (IntImm cast to GemmWarpPolicy)
 *   args[9]     clear_accum (Bool)
 *   args[10]    optional kPack, must be 1 or 2
 *   args[11]    optional wg_wait
 *
 * \param args The call arguments described above.
 * \param vmap Map from buffer data variables to buffers.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  Aptr = args[0];
  Bptr = args[1];
  Cptr = args[2];
  A = vmap[GetVarFromAccessPtr(Aptr)];
  B = vmap[GetVarFromAccessPtr(Bptr)];
  C = vmap[GetVarFromAccessPtr(Cptr)];

  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();

  // Trailing arguments are optional; members keep their defaults otherwise.
  if (args.size() > 10) {
    kPack = args[10].as<IntImm>().value()->value;
    ICHECK(kPack == 1 || kPack == 2) << "kPack must be 1 or 2";
  }
  if (args.size() > 11) {
    wg_wait = args[11].as<IntImm>().value()->value;
  }
}

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                     (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target->str();
  }
}

/*!
 * \brief Split the warps of a block between the M and N dimensions.
 *
 * M must be a multiple of 16 and N a multiple of 8 (checked). For WGMMA the
 * number of warps along M is always a multiple of 4 (one warp-group) and the
 * product m_warp * n_warp is checked to equal the number of warps. For the
 * other instruction kinds every candidate considered divides the number of
 * warps, so the product also equals the number of warps.
 *
 * \param block_size Number of threads in the block.
 * \param gemm_inst The instruction kind returned by GetGemmInst.
 * \param target The compilation target (used for the warp size).
 * \return {m_warp, n_warp}: warps assigned to the M and N dimensions.
 */
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
                                               GemmInst gemm_inst,
                                               Target target) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp

  ICHECK(this->M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << this->M;
  ICHECK(this->N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << this->N;

  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->policy == GemmWarpPolicy::kFullRow) {
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (this->M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                       // Initially assume all on N
      if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = this->N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->policy == GemmWarpPolicy::kSquare) {
      // Exhaustive search, but m must be multiple of 4
      int max_m = this->M / kMPerWarp;
      int max_n = this->N / kNPerWarp;

      float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";
    return {m_warp, n_warp};
  }

  if (this->policy == GemmWarpPolicy::kFullRow) {
    ICHECK(num_warps > 0)
        << "Gemm needs at least one full warp, but block_size is "
        << block_size;
    // Put as many warps on M as possible: the largest count that divides both
    // num_warps and the number of 16-row tiles in M. This equals num_warps
    // when M splits evenly, and M / 16 when that value divides num_warps;
    // in every case m_warp * n_warp == num_warps.
    m_warp = std::gcd(num_warps, this->M / kMPerWarp);
    // Remaining warps go to N
    n_warp = num_warps / m_warp;
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    ICHECK(num_warps > 0)
        << "Gemm needs at least one full warp, but block_size is "
        << block_size;
    // Same as above with the roles of M and N swapped (8-column tiles).
    n_warp = std::gcd(num_warps, this->N / kNPerWarp);
    // Remaining warps go to M
    m_warp = num_warps / n_warp;
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    // Each warp needs at least 16 elements in M
    const int max_m_warps = this->M / kMPerWarp;

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (this->N > 0) {
      ideal_ratio = static_cast<float>(this->M) / this->N;
    }

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();

    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      // Only exact factorizations are valid; otherwise m * n would silently
      // drop warps (e.g. 16 warps split as 3 x 5).
      if (num_warps % m != 0)
        continue;
      int n = num_warps / m;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

239
240
241
242
bool Gemm::CheckWGMMA() const {
  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
243
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
244
      return (!trans_A) && trans_B && K % 32 == 0;
245
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
246
      return (!trans_A) && trans_B && K % 32 == 0;
247
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
248
      return (!trans_A) && trans_B && K % 32 == 0;
249
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
250
251
252
253
254
255
256
257
258
259
260
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
261
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
262
      return (!trans_A) && trans_B && K % 32 == 0;
263
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
264
      return (!trans_A) && trans_B && K % 32 == 0;
265
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
266
      return (!trans_A) && trans_B && K % 32 == 0;
267
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}

287
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
288
  auto block_size = *as_const_int(T.thread_bounds->extent);
289
290
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);
291

292
293
294
295
296
297
298
299
300
301
302
  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
303
  ss << ", " << clear_accum;
304
305
306
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
307
  } else if (TargetIsHopper(T.target)) {
308
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
309
  }
310
311
312
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
313
  ss << ">";
314
315
316

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
317
318
319
  return Evaluate(new_call);
}

320
321
322
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
323
324
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
325
326
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
327
328
329
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

330
  if (TargetIsVolta(T.target)) {
331
332
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
333
    results.Set(C, fragment->BindThreadRange(thread_range));
334
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
335
336
337
338
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
339
340
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
341
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
342
      results.Set(A, fragment->BindThreadRange(thread_range));
343
344
345
346
347
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
348
349
350
351
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
352
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
353
354
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
355
    results.Set(C, fragment->BindThreadRange(thread_range));
356
357

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
358
359
360
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
361
362
363
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
364
    } else if (A.scope() == "local.fragment") {
365
366
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
367
      results.Set(A, fragment->BindThreadRange(thread_range));
368
369
370
371
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
372
373
374
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
375
376
377
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
378
    } else if (B.scope() == "local.fragment") {
379
380
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
381
      results.Set(B, fragment->BindThreadRange(thread_range));
382
383
384
385
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
386
    auto fragment =
387
        gemm_inst == GemmInst::kWGMMA
388
389
390
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
391
    results.Set(C, fragment->BindThreadRange(thread_range));
392
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
393
394
395
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
396
      const int64_t continuity =
397
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
398
      auto ABLayout =
399
          gemm_inst == GemmInst::kWGMMA
400
401
402
403
404
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), trans_A ? 1 : 2)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), trans_A ? 1 : 2);
      results.Set(A, ABLayout);
405
    } else {
406
407
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
408
      results.Set(A, fragment->BindThreadRange(thread_range));
409
410
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
411
412
413
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
414
415
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
416
      auto ABLayout =
417
          gemm_inst == GemmInst::kWGMMA
418
419
420
421
422
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B ? 2 : 1)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B ? 2 : 1);
      results.Set(B, ABLayout);
423
424
425
426
    } else {
      ICHECK(0) << "WGMMA only support B in shared.";
    }
  } else if (TargetIsCDNA(T.target)) {
427
428
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
429
    results.Set(C, fragment->BindThreadRange(thread_range));
430
431

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
432
433
434
435
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
436
437
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
438
439
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
440
      results.Set(A, fragment->BindThreadRange(thread_range));
441
442
443
444
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
445
446
447
448
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
449
450
451

      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
452
453
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
454
      results.Set(B, fragment->BindThreadRange(thread_range));
455
456
457
458
459
460
461
462
463
464
465
466
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register the gemm operator as an opaque call.
// NOTE(review): the Gemm constructor in this file reads at least ten
// arguments while five inputs are declared here; confirm this is intended.
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm