"examples/model_compress/vscode:/vscode.git/clone" did not exist on "e773dfcceed9719a6ce8afa0cff35561060f0cc6"
gemm.cc 18 KB
Newer Older
1
2
3
4
5
6
7
8
/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

9
#include "builtin.h"
10
11
12
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
13
#include <tvm/tir/transform.h>
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

// Decompose a positive integer into its prime factors, ascending and with
// multiplicity (e.g. 12 -> {2, 2, 3}). Inputs <= 1 yield an empty vector.
static std::vector<int> toPrimeFactors(int x) {
  std::vector<int> factors;
  int divisor = 2;
  while (x > 1) {
    if (x % divisor != 0) {
      // Not a factor; try the next candidate. Composite candidates can never
      // divide x here because their prime parts were already divided out.
      ++divisor;
      continue;
    }
    factors.push_back(divisor);
    x /= divisor;
  }
  return factors;
}

// Build a Gemm operator from the packed TL call arguments.
//
// Argument layout (as consumed below):
//   args[0..2]   access pointers for A, B, C; resolved to buffers via vmap
//   args[3..4]   transpose flags trans_A, trans_B
//   args[5..7]   problem shape M, N, K
//   args[8]      warp-partition policy (GemmWarpPolicy)
//   args[9]      clear_accum: zero-initialize the accumulator
//   args[10..13] stride_A, stride_B, offset_A, offset_B
//   args[14]     optional kPack; must be 1 or 2
//   args[15]     optional wg_wait count
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  Aptr = args[0];
  Bptr = args[1];
  Cptr = args[2];
  A = vmap[GetVarFromAccessPtr(Aptr)];
  B = vmap[GetVarFromAccessPtr(Bptr)];
  C = vmap[GetVarFromAccessPtr(Cptr)];

  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();

  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;

  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();

  stride_A = args[10].as<IntImm>().value()->value;
  stride_B = args[11].as<IntImm>().value()->value;
  offset_A = args[12].as<IntImm>().value()->value;
  offset_B = args[13].as<IntImm>().value()->value;

  // Trailing arguments are optional; older call sites may omit them.
  if (args.size() > 14) {
    kPack = args[14].as<IntImm>().value()->value;
    if (kPack != 1 && kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    wg_wait = args[15].as<IntImm>().value()->value;
  }
}

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// Select the tensor-core instruction family used to lower this GEMM.
//
// WGMMA is chosen only on Hopper when the tile is at least one warp-group
// tall (M >= 64), the block carries whole warp-groups (num_warps % 4 == 0),
// and the dtype/transpose combination is supported (CheckWGMMA). Otherwise
// CDNA targets use MFMA and CUDA targets use MMA.
Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                     (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  }
  ICHECK(0) << "Unsupported target for gemm: " << target->str();
  // BUGFIX: the function previously fell off the end here (no return after
  // the failing ICHECK), which is UB and a -Wreturn-type warning.
  return GemmInst::kMMA; // Unreachable; ICHECK above aborts.
}

// Split the block's warps across the M and N dimensions of the output tile.
//
// Each warp covers at least kMPerWarp rows and kNPerWarp columns. For WGMMA,
// warps are assigned in whole warp-groups of 4 and the partition must use
// every warp exactly; for the MMA/MFMA path the partition is best-effort.
//
// \param block_size number of threads in the block
// \param gemm_inst  instruction family chosen by GetGemmInst
// \param target     compilation target (determines warp size)
// \return {m_warp, n_warp}
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
                                               GemmInst gemm_inst,
                                               Target target) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  int m_warp = 1, n_warp = 1;

  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp

  ICHECK(this->M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << this->M;
  ICHECK(this->N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << this->N;

  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup;             // Initially, one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->policy == GemmWarpPolicy::kFullRow) {
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M).
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (this->M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M.
      int cand_n = n_warp;                       // Initially assume all on N
      if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = this->N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is a multiple of 4.
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->policy == GemmWarpPolicy::kSquare) {
      // Exhaustive search; m must stay a multiple of the warp-group size.
      int max_m = this->M / kMPerWarp;
      int max_n = this->N / kNPerWarp;

      float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";
    return {m_warp, n_warp};
  }

  // ---- Non-WGMMA (MMA / MFMA) path: no warp-group constraint. ----
  if (this->policy == GemmWarpPolicy::kFullRow) {
    // Try to partition M first.
    m_warp = num_warps;
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining
    // warps to N.
    if (this->M % (m_warp * kMPerWarp) != 0) {
      // Calculate how many warps we can use for M.
      int max_m_warps = this->M / kMPerWarp;
      m_warp = max_m_warps;
      // Use remaining warps for N (best-effort; a remainder leaves warps
      // unassigned, matching historical behavior of this policy).
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    // Try to partition N first.
    m_warp = 1;
    n_warp = num_warps;

    // If N cannot be evenly divided by n_warp*8, try to split remaining
    // warps to M.
    if (this->N % (n_warp * kNPerWarp) != 0) {
      // Calculate how many warps we can use for N.
      int max_n_warps = this->N / kNPerWarp;
      n_warp = max_n_warps;
      // Use remaining warps for M (best-effort, as above).
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    // Maximum useful warps per dimension given the per-warp minimum tile.
    int max_m_warps =
        this->M / kMPerWarp; // Each warp needs at least 16 elements in M
    int max_n_warps =
        this->N / kNPerWarp; // Each warp needs at least 8 elements in N

    // Ideal M/N warp ratio follows the matrix aspect ratio.
    float ideal_ratio = 1.0f;
    if (this->N > 0) {
      ideal_ratio = static_cast<float>(this->M) / this->N;
    }

    m_warp = 1;
    n_warp = 1;

    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();

    // Try all factorizations m * n == num_warps within the constraints.
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      // BUGFIX: require m to divide num_warps (mirrors the WGMMA kSquare
      // search above). Previously num_warps/m truncated, so a non-divisor m
      // could win the search and leave m*n < num_warps warps idle.
      if (num_warps % m != 0)
        continue;
      int n = num_warps / m;

      float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

243
244
245
246
bool Gemm::CheckWGMMA() const {
  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
247
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
248
      return (!trans_A) && trans_B && K % 32 == 0;
249
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
250
      return (!trans_A) && trans_B && K % 32 == 0;
251
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
252
      return (!trans_A) && trans_B && K % 32 == 0;
253
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
254
255
256
257
258
259
260
261
262
263
264
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
265
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
266
      return (!trans_A) && trans_B && K % 32 == 0;
267
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
268
      return (!trans_A) && trans_B && K % 32 == 0;
269
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
270
      return (!trans_A) && trans_B && K % 32 == 0;
271
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}

291
292
293
294
295
296
297
298
299
300
301
302
303
// Parse the numeric compute capability out of the target's "arch" attribute
// (e.g. "sm_90" -> 90). Returns 0 when the attribute does not start with
// the "sm_" prefix. The attribute must be present (ICHECK).
static int GetArchInt(Target target) {
  auto arch_attr = target->GetAttr<String>("arch");
  ICHECK(arch_attr.defined());
  std::string arch = arch_attr.value();
  if (arch.rfind("sm_", 0) != 0) {
    return 0;
  }
  return atoi(arch.c_str() + 3);
}

304
// Lower the gemm op to a call of the templated device kernel.
//
// Builds the C++ template-instantiation string (tile shape, warp partition,
// transpose/accumulate flags and target-specific extras) and wraps it in a
// tl_gemm builtin call carrying the A/B/C access pointers.
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // BUGFIX: fail with a diagnostic instead of dereferencing a null pointer
  // when the thread extent is not a compile-time constant.
  const int64_t *extent_ptr = as_const_int(T.thread_bounds->extent);
  ICHECK(extent_ptr != nullptr)
      << "gemm lowering requires a constant thread extent, got "
      << T.thread_bounds->extent;
  auto block_size = *extent_ptr;
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  // Kernel variant depends on where the operands live: both shared (ss),
  // A in registers (rs), or B in registers (sr).
  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  // sm75+ CUDA kernels additionally take explicit strides and offsets.
  if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
    ss << ", " << stride_A << ", " << stride_B;
    ss << ", " << offset_A << ", " << offset_B;
  }
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    // Tell the Hopper kernel whether to take the WGMMA path.
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";

  auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
                       Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
  return Evaluate(new_call);
}

341
342
343
// Infer the memory layouts of A, B and C for this gemm, per target family.
//
// C always receives a register-fragment layout bound to the thread range.
// A and B receive either a shared-memory swizzle layout or a fragment
// layout, depending on their scope. Runs at most once (completed_ guard);
// `level` is currently unused.
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
  LayoutMap results;
  // The accumulator must live in per-thread registers.
  ICHECK(C.scope() == "local.fragment");
  auto thread_range = T.thread_bounds;
  // BUGFIX: check the extent is constant instead of blindly dereferencing.
  const int64_t *extent_ptr = as_const_int(thread_range->extent);
  ICHECK(extent_ptr != nullptr)
      << "gemm layout inference requires a constant thread extent, got "
      << thread_range->extent;
  auto block_size = *extent_ptr;
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  if (TargetIsVolta(T.target)) {
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      // Shared layouts are derived from the trailing two buffer dimensions.
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    // On Volta, B must be staged through shared memory.
    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    // WGMMA uses its own accumulator fragment; the MMA fallback reuses the
    // generic one.
    auto fragment =
        gemm_inst == GemmInst::kWGMMA
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), trans_A ? 1 : 2)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), trans_A ? 1 : 2);
      results.Set(A, ABLayout);
    } else {
      // BUGFIX: this branch used to assume a register fragment without
      // checking; verify the scope explicitly, mirroring the other targets.
      ICHECK(A.scope() == "local.fragment")
          << "Unsupported A scope for Hopper gemm: " << A.scope();
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }

    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B ? 2 : 1)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B ? 2 : 1);
      results.Set(B, ABLayout);
    } else {
      ICHECK(0) << "WGMMA only support B in shared.";
    }
  } else if (TargetIsCDNA(T.target)) {
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register tl.gemm with five inputs and effect kind kOpaque, so passes
// treat calls conservatively (not pure / not removable).
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
490

491
492
} // namespace tl
} // namespace tvm