/*!
 * \file tl/op/gemm.cc
 * \brief Define gemm operator.
 */

#include "gemm.h"

#include <algorithm> // std::min
#include <cmath>     // std::abs(float)
#include <cstdlib>   // atoi
#include <limits>    // std::numeric_limits
#include <sstream>   // std::stringstream
#include <string>
#include <utility>
#include <vector>

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

// Decompose a positive integer into its prime factors, in ascending order
// and with multiplicity (e.g. 12 -> {2, 2, 3}). Returns an empty vector
// for x <= 1.
static std::vector<int> toPrimeFactors(int x) {
  std::vector<int> factors;
  int divisor = 2;
  while (x > 1) {
    if (x % divisor != 0) {
      ++divisor;
      continue;
    }
    x /= divisor;
    factors.push_back(divisor);
  }
  return factors;
}

// Construct a Gemm operator from the packed TL call arguments.
//
// Expected layout of `args`:
//   [0..2]   access pointers for A, B, C
//   [3..4]   trans_A, trans_B flags
//   [5..7]   M, N, K problem sizes
//   [8]      warp-partition policy (GemmWarpPolicy)
//   [9]      clear_accum flag
//   [10..13] stride_A, stride_B, offset_A, offset_B
//   [14]     (optional) kPack, must be 1 or 2
//   [15]     (optional) wg_wait
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  // args[0..13] are read unconditionally below; fail loudly instead of
  // indexing out of bounds on a malformed call.
  ICHECK(args.size() >= 14) << "Gemm expects at least 14 arguments, got "
                            << args.size();
  Aptr = args[0];
  Bptr = args[1];
  Cptr = args[2];
  // Resolve the buffers behind the access pointers.
  A = vmap[GetVarFromAccessPtr(Aptr)];
  B = vmap[GetVarFromAccessPtr(Bptr)];
  C = vmap[GetVarFromAccessPtr(Cptr)];
  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();
  stride_A = args[10].as<IntImm>().value()->value;
  stride_B = args[11].as<IntImm>().value()->value;
  offset_A = args[12].as<IntImm>().value()->value;
  offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    kPack = args[14].as<IntImm>().value()->value;
    ICHECK(kPack == 1 || kPack == 2) << "kPack must be 1 or 2, got " << kPack;
  }
  if (args.size() > 15) {
    wg_wait = args[15].as<IntImm>().value()->value;
  }
}

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// Select the MMA instruction family used to lower this GEMM.
//
// WGMMA is chosen on Hopper when the shape and thread count allow it
// (M >= 64, a whole multiple of 4 warps, and CheckWGMMA() accepts the
// dtype/K combination); otherwise CDNA targets use MFMA and CUDA targets
// use MMA. Any other target is rejected.
Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                     (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target->str();
  }
  // Unreachable (ICHECK aborts above); keeps every path returning a value.
  return GemmInst::kMMA;
}

// Split the block's warps (block_size / warp_size) into an (m_warp, n_warp)
// grid according to this->policy.
//
// Each warp tile covers kMPerWarp (16) rows of M and kNPerWarp (8) columns
// of N. For WGMMA, warps are handled in warp-groups of 4 along M, so m_warp
// is always a multiple of 4 on that path.
//
// Returns {m_warp, n_warp}; on the WGMMA path this is checked to multiply
// back to num_warps.
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
                                               GemmInst gemm_inst,
                                               Target target) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp

  ICHECK(this->M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << this->M;
  ICHECK(this->N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << this->N;
  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->policy == GemmWarpPolicy::kFullRow) {
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M)
      // NOTE(review): cand is not required to divide num_warps, so
      // num_warps / m_warp below can truncate; the m_warp * n_warp ==
      // num_warps check at the end would then fire — confirm intended.
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (this->M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                       // Initially assume all on N
      if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = this->N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->policy == GemmWarpPolicy::kSquare) {
      // Exhaustive search, but m must be multiple of 4
      int max_m = this->M / kMPerWarp;
      int max_n = this->N / kNPerWarp;

      // Target ratio of per-warp M work to per-warp N work.
      float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue; // m must divide num_warps exactly
        int n = num_warps / m;
        if (n > max_n)
          continue; // not enough N columns for n warps

        float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";
    return {m_warp, n_warp};
  }

  // Non-WGMMA path: warps are free-form (no warp-group constraint).
  if (this->policy == GemmWarpPolicy::kFullRow) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (this->M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = this->M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N
      // NOTE(review): if m_warp does not divide num_warps the remainder
      // warps are dropped here (no m_warp * n_warp == num_warps check on
      // this path) — confirm intended.
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (this->N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = this->N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    // First calculate the maximum possible warps for each dimension
    int max_m_warps =
        this->M / kMPerWarp; // Each warp needs at least 16 elements in M
    int max_n_warps =
        this->N / kNPerWarp; // Each warp needs at least 8 elements in N

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (this->N > 0) {
      ideal_ratio = static_cast<float>(this->M) / this->N;
    }

    // Start with a balanced initial guess
    m_warp = 1;
    n_warp = 1;

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();

    // Try all possible combinations that satisfy the constraints
    // NOTE(review): unlike the WGMMA kSquare search above, this loop does
    // not require num_warps % m == 0 or n <= max_n_warps — confirm whether
    // partitions with unused warps are acceptable here.
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

bool Gemm::CheckWGMMA() const {
  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
247
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
248
      return (!trans_A) && trans_B && K % 32 == 0;
249
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
250
      return (!trans_A) && trans_B && K % 32 == 0;
251
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
252
      return (!trans_A) && trans_B && K % 32 == 0;
253
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
254
255
256
257
258
259
260
261
262
263
264
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
265
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
266
      return (!trans_A) && trans_B && K % 32 == 0;
267
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
268
      return (!trans_A) && trans_B && K % 32 == 0;
269
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
270
      return (!trans_A) && trans_B && K % 32 == 0;
271
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}

// Parse the numeric SM version out of the target's "arch" attribute
// (e.g. "sm_90" -> 90). Returns 0 when the attribute is not of the
// "sm_<num>" form.
static int GetArchInt(Target target) {
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  // Copy into a local: the previous code took c_str() of the temporary
  // returned by s.value(), leaving a dangling pointer for the parse below.
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    return atoi(arch.c_str() + 3);
  }
  return 0;
}

// Lower this GEMM to an Evaluate(Call) of the device-side templated kernel
// named by a generated string, e.g. "tl::gemm_ss<M, N, K, ...>".
//
// Kernel flavor is chosen from the operand scopes:
//   A, B both in shared memory         -> tl::gemm_ss
//   A in local.fragment, B in shared   -> tl::gemm_rs
//   B in local.fragment, A in shared   -> tl::gemm_sr
// Both operands in local.fragment is rejected.
//
// `analyzer` is unused here; it is part of the common Lower signature.
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  // Template argument order below must match the device-side kernel
  // declaration; later arguments are appended target-conditionally.
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  // Strides/offsets are only part of the template signature on CUDA
  // targets with sm_75 or newer.
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << stride_A << ", " << stride_B;
    ss << ", " << offset_A << ", " << offset_B;
  }
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    // Hopper kernels take an extra flag selecting WGMMA vs MMA lowering.
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }
  // NOTE(review): wg_wait is appended positionally only when nonzero —
  // presumably the kernel template defaults this parameter to 0; confirm.
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";

  // The kernel name string plus the three access pointers form the call.
  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}

// Infer memory/fragment layouts for A, B and C for the selected target and
// MMA instruction, binding fragment layouts to the block's thread range.
//
// Runs once: after the first successful inference `completed_` is set and
// subsequent calls return an empty map. C must live in local.fragment.
// `level` is unused here; it is part of the common InferLayout signature.
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  if (TargetIsVolta(T.target)) {
    // Volta: dedicated fragment/AB layout builders; B must be in shared.
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      // Layout is built from the last two shape dims (stride, contiguous).
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM120(T.target)) {
    // Turing / Ampere / SM120: generic MMA fragments; A and B may each be
    // in shared memory or already in a register fragment.
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    // Hopper: WGMMA gets its own C fragment and shared-memory AB layouts;
    // otherwise fall back to the generic MMA layouts.
    auto fragment =
        gemm_inst == GemmInst::kWGMMA
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      // NOTE(review): for transposed A the continuity is scaled by
      // 4 / warp_m (warp-group partitioning along M) — confirm against the
      // WGMMA layout requirements.
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), trans_A ? 1 : 2)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), trans_A ? 1 : 2);
      results.Set(A, ABLayout);
    } else {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B ? 2 : 1)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B ? 2 : 1);
      results.Set(B, ABLayout);
    } else {
      ICHECK(0) << "WGMMA only support B in shared.";
    }
  } else if (TargetIsCDNA(T.target)) {
    // CDNA (AMD MFMA): layouts additionally depend on kPack.
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);

      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register `tl.gemm` as an opaque TIR op (no side-effect analysis through
// the call).
// NOTE(review): set_num_inputs(5) does not match the >= 14 packed arguments
// the Gemm constructor reads; confirm whether the registered arity is
// meaningful for TL ops or merely informational.
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm