"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "c1d244d9349d05fd2cd029c7d455fa8a7f15ed63"
/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Decompose an integer into its prime factors.
 *
 * \param x The value to factorise.
 * \return The prime factors of x in non-decreasing order, repeated according
 *         to multiplicity. Empty when x <= 1 (including negative values).
 */
static std::vector<int> toPrimeFactors(int x) {
  std::vector<int> result;
  // Only divisors up to sqrt(x) need to be tried. The bound is written as
  // `i <= x / i` rather than `i * i <= x` so it cannot overflow for x close
  // to INT_MAX.
  for (int i = 2; i <= x / i; ++i) {
    while (x % i == 0) {
      result.push_back(i);
      x /= i;
    }
  }
  // Whatever is left above 1 is a single prime factor larger than sqrt(x).
  if (x > 1) {
    result.push_back(x);
  }
  return result;
}

/*!
 * \brief Build a Gemm operator from the positional arguments of a tl.gemm
 *        call.
 *
 * Argument layout, as read below:
 *   args[0..2]  access pointers of A, B and C
 *   args[3..4]  trans_A, trans_B (Bool)
 *   args[5..7]  M, N, K (IntImm)
 *   args[8]     warp partition policy (IntImm, cast to GemmWarpPolicy)
 *   args[9]     clear_accum (Bool)
 *   args[10]    optional kPack (IntImm, must be 1 or 2)
 *   args[11]    optional wg_wait (IntImm)
 *
 * \param args Positional call arguments.
 * \param vmap Map from buffer data variables to buffers, used to resolve the
 *             buffers behind the three access pointers.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  // Fail with a descriptive message instead of an out-of-range index error.
  ICHECK(args.size() >= 10)
      << "gemm expects at least 10 arguments, but got " << args.size();

  Aptr = args[0];
  Bptr = args[1];
  Cptr = args[2];
  A = vmap[GetVarFromAccessPtr(Aptr)];
  B = vmap[GetVarFromAccessPtr(Bptr)];
  C = vmap[GetVarFromAccessPtr(Cptr)];

  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();

  // Optional trailing arguments keep their member defaults when absent.
  if (args.size() > 10) {
    kPack = args[10].as<IntImm>().value()->value;
    ICHECK(kPack == 1 || kPack == 2) << "kPack must be 1 or 2";
  }
  if (args.size() > 11) {
    wg_wait = args[11].as<IntImm>().value()->value;
  }
}

/*!
 * \brief Pick the matrix-multiply instruction family for this gemm.
 *
 * \param block_size Number of threads in the block.
 * \param target Compilation target.
 * \return kWGMMA when the target is Hopper, M >= 64, the warp count is a
 *         multiple of 4 and CheckWGMMA() accepts the operands; otherwise
 *         kMFMA for CDNA targets and kMMA for other CUDA targets.
 *
 * Fails an ICHECK for any other target.
 */
Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
  const int warp_size = TargetGetWarpSize(target);
  const int num_warps = block_size / warp_size;
  const bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                           (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  }
  if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  }
  if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  }
  ICHECK(0) << "Unsupported target for gemm: " << target->str();
  // Not expected to be reached: the failed check above should not return.
  // The explicit return keeps every control path of this non-void function
  // well-defined and silences -Wreturn-type.
  return GemmInst::kMMA;
}

/*!
 * \brief Split the warps of a thread block into an (m_warp, n_warp) grid.
 *
 * \param block_size Number of threads in the block.
 * \param gemm_inst Instruction family chosen by GetGemmInst().
 * \param target Compilation target (used for the warp size).
 * \return {m_warp, n_warp}: warps assigned along M and along N.
 *
 * M must be a multiple of 16 and N a multiple of 8 (checked). For kWGMMA the
 * warp count must be a multiple of 4, m_warp is kept a multiple of 4 and
 * m_warp * n_warp == num_warps is checked before returning.
 *
 * The policy selects the preference:
 *   kFullRow - as many warps as possible along M,
 *   kFullCol - as many warps as possible along N,
 *   kSquare  - the split whose per-warp tile ratio scores best.
 */
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
                                               GemmInst gemm_inst,
                                               Target target) const {
  int num_warps = block_size / TargetGetWarpSize(target);
  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp

  ICHECK(this->M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << this->M;
  ICHECK(this->N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << this->N;

  if (gemm_inst == GemmInst::kWGMMA) {
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup;             // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->policy == GemmWarpPolicy::kFullRow) {
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4). A candidate must also divide num_warps,
      // otherwise m_warp * n_warp cannot equal num_warps and the final check
      // below would fail even though a smaller candidate is valid.
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (num_warps % cand == 0 && this->M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                       // Initially assume all on N
      if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = this->N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->policy == GemmWarpPolicy::kSquare) {
      // Exhaustive search, but m must be multiple of 4
      int max_m = this->M / kMPerWarp;
      int max_n = this->N / kNPerWarp;

      float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";
    return {m_warp, n_warp};
  }

  // Largest warp count w in [1, num_warps] that divides num_warps and splits
  // `extent` into whole multiples of `per_warp`; 1 when nothing larger fits.
  // Restricting to divisors keeps m_warp * n_warp == num_warps, and never
  // dividing by extent / per_warp avoids a division by zero for extent == 0.
  const auto largest_split = [num_warps](int extent, int per_warp) {
    for (int w = num_warps; w > 1; --w) {
      if (num_warps % w == 0 && extent % (w * per_warp) == 0) {
        return w;
      }
    }
    return 1;
  };

  if (this->policy == GemmWarpPolicy::kFullRow) {
    // Put as many warps as possible on M, the remaining factor goes to N.
    m_warp = largest_split(this->M, kMPerWarp);
    n_warp = num_warps / m_warp;
    if (n_warp == 0)
      n_warp = 1;
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    // Put as many warps as possible on N, the remaining factor goes to M.
    n_warp = largest_split(this->N, kNPerWarp);
    m_warp = num_warps / n_warp;
    if (m_warp == 0)
      m_warp = 1;
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    // First calculate the maximum possible warps for each dimension
    const int max_m_warps =
        this->M / kMPerWarp; // Each warp needs at least 16 elements in M
    const int max_n_warps =
        this->N / kNPerWarp; // Each warp needs at least 8 elements in N

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (this->N > 0) {
      ideal_ratio = static_cast<float>(this->M) / this->N;
    }

    // Fallback when no candidate satisfies the constraints: all warps on N.
    int best_m = 1;
    int best_n = num_warps > 0 ? num_warps : 1;
    float best_balance = std::numeric_limits<float>::max();

    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      // m must divide num_warps so that every warp is assigned.
      if (num_warps % m != 0)
        continue;
      int n = num_warps / m;
      // Each warp along N needs at least kNPerWarp columns.
      if (n > max_n_warps)
        continue;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

bool Gemm::CheckWGMMA() const {
  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
243
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
244
      return (!trans_A) && trans_B && K % 32 == 0;
245
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
246
      return (!trans_A) && trans_B && K % 32 == 0;
247
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
248
      return (!trans_A) && trans_B && K % 32 == 0;
249
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
250
251
252
253
254
255
256
257
258
259
260
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
261
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
262
      return (!trans_A) && trans_B && K % 32 == 0;
263
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
264
      return (!trans_A) && trans_B && K % 32 == 0;
265
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
266
      return (!trans_A) && trans_B && K % 32 == 0;
267
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}

287
/*!
 * \brief Lower the gemm to an extern call of a templated device function.
 *
 * The callee name is built as
 *   op<M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum[, X][, wg_wait]>
 * where op is tl::gemm_rs when A is a fragment, tl::gemm_sr when B is a
 * fragment and tl::gemm_ss otherwise; X is kPack on CDNA targets and
 * "true"/"false" (whether WGMMA was selected) on Hopper targets; wg_wait is
 * appended only when it is non-zero.
 *
 * \param T Lowering context (target, thread bounds, buffer remap).
 * \param analyzer Unused here.
 * \return An Evaluate of call_extern(name, Aptr, Bptr, Cptr).
 */
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // as_const_int yields a null pointer for non-constant expressions; check
  // before dereferencing.
  const auto *extent = as_const_int(T.thread_bounds->extent);
  ICHECK(extent != nullptr) << "Gemm::Lower requires a constant thread extent";
  auto block_size = *extent;
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";

  // The accumulator is expected to have an entry in the remap; previously
  // this was enforced only implicitly by an otherwise unused lookup.
  ICHECK(T.buffer_remap.count(C))
      << "Gemm::Lower: accumulator buffer is missing from buffer_remap";

  Array<PrimExpr> new_args;
  new_args.push_back(StringImm(ss.str()));
  new_args.push_back(Aptr);
  new_args.push_back(Bptr);
  new_args.push_back(Cptr);
  auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
  return Evaluate(new_call);
}

/*!
 * \brief Infer the layouts of A, B and C for the current target.
 *
 * C must live in "local.fragment". A and B may be in "shared"/"shared.dyn"
 * or, where the target branch allows it, in "local.fragment". The layout
 * builders are chosen per architecture (Volta, Ampere/Turing, Hopper, CDNA).
 *
 * \param T Inference context (target and thread bounds).
 * \param level Unused here.
 * \return The inferred layouts on the first call; an empty map on later
 *         calls (tracked through completed_).
 */
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};

  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");

  auto thread_range = T.thread_bounds;
  // as_const_int yields a null pointer for non-constant expressions; check
  // before dereferencing.
  const auto *extent = as_const_int(thread_range->extent);
  ICHECK(extent != nullptr)
      << "Gemm::InferLayout requires a constant thread extent";
  auto block_size = *extent;
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

  // True for buffers placed in static or dynamic shared memory scope.
  const auto in_shared = [](const auto &buf) {
    return buf.scope() == "shared" || buf.scope() == "shared.dyn";
  };

  if (TargetIsVolta(T.target)) {
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (in_shared(A)) {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    // This branch only accepts B in shared memory.
    ICHECK(in_shared(B));
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (in_shared(A)) {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    if (in_shared(B)) {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    // Hopper uses the WGMMA builders only when WGMMA was selected and the
    // regular MMA builders otherwise.
    auto fragment =
        gemm_inst == GemmInst::kWGMMA
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (in_shared(A)) {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       A->dtype.bits(), trans_A ? 1 : 2)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(), trans_A ? 1 : 2);
      results.Set(A, ABLayout);
    } else {
      // NOTE(review): any non-shared scope of A is treated as a fragment
      // here, without the scope check the other branches perform - confirm.
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }

    if (in_shared(B)) {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      auto ABLayout =
          gemm_inst == GemmInst::kWGMMA
              ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                       B->dtype.bits(), trans_B ? 2 : 1)
              : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 B->dtype.bits(), trans_B ? 2 : 1);
      results.Set(B, ABLayout);
    } else {
      ICHECK(0) << "WGMMA only support B in shared.";
    }
  } else if (TargetIsCDNA(T.target)) {
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (in_shared(A)) {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    if (in_shared(B)) {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      // NOTE(review): A uses the CDNA-specific fragment builder while B uses
      // the generic makeGemmFragmentB - confirm this is intended.
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }

  completed_ = true;
  return results;
}

// Register the Gemm operator under the name `gemm` and mark calls to it as
// opaque.
// NOTE(review): 5 inputs are declared here, while the Gemm constructor reads
// at least args[0..9] - confirm that the declared count is intentional.
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm