/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

9
#include "builtin.h"
10
11
12
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
13
#include <tvm/tir/transform.h>
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

// Decompose a positive integer into its prime factors, returned in
// non-decreasing order (e.g. 12 -> {2, 2, 3}); 1 yields an empty list.
static std::vector<int> toPrimeFactors(int x) {
  std::vector<int> factors;
  int p = 2;
  while (x > 1) {
    if (x % p != 0) {
      // p does not divide x; try the next candidate divisor.
      ++p;
      continue;
    }
    factors.push_back(p);
    x /= p;
  }
  return factors;
}

/*!
 * \brief Construct a Gemm operator from TL call arguments.
 *
 * Expected argument layout:
 *   args[0..2]  : access pointers for A, B, C
 *   args[3..4]  : trans_A, trans_B (Bool)
 *   args[5..7]  : M, N, K (IntImm)
 *   args[8]     : warp-partition policy (IntImm, cast to GemmWarpPolicy)
 *   args[9]     : clear_accum (Bool)
 *   args[10]    : optional kPack (CDNA only; must be 1 or 2)
 *   args[11]    : optional wg_wait
 *
 * \param args  Packed call arguments as described above.
 * \param vmap  Maps buffer vars (from the access pointers) to Buffers.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  Aptr = args[0];
  Bptr = args[1];
  Cptr = args[2];
  // Resolve the underlying buffers from the access-pointer expressions.
  A = vmap[GetVarFromAccessPtr(Aptr)];
  B = vmap[GetVarFromAccessPtr(Bptr)];
  C = vmap[GetVarFromAccessPtr(Cptr)];
  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();
  // Trailing arguments are optional; defaults come from the class.
  if (args.size() > 10) {
    kPack = args[10].as<IntImm>().value()->value;
    if (kPack != 1 && kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 11) {
    wg_wait = args[11].as<IntImm>().value()->value;
  }
}

61
62
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
                                               bool maybe_hopper_wgmma) const {
63
  int m_warp = 1, n_warp = 1;
64
65
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
  constexpr int kNPerWarp = 8;  // Columns processed by a single warp
66
67
  bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
                     (this->M >= 64) && (num_warps % 4 == 0);
68
69
70
71
  ICHECK(this->M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << this->M;
  ICHECK(this->N % kNPerWarp == 0)
      << "N must be divisible by " << kNPerWarp << ", but got " << this->N;
72
  if (allow_wgmma) {
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

    constexpr int kGroup = 4; // Number of warps in a warp-group

    m_warp = kGroup; // Initially, only one warp-group on M dimension
    n_warp = num_warps / m_warp; // Rest all on N dimension

    if (this->policy == GemmWarpPolicy::kFullRow) {
      // Try to put as many warp-groups as possible on M dimension
      // (decreasing multiples of 4, ensuring divisibility by M)
      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
        if (this->M % (cand * kMPerWarp) == 0) {
          m_warp = cand;
          n_warp = num_warps / m_warp;
          break;
        }
      }
90
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
      // Try to use warps on N dimension; if N is not divisible, split excess
      // groups to M
      int cand_n = n_warp;                       // Initially assume all on N
      if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
        int max_n = this->N / kNPerWarp;
        // Find a feasible n_warp from max possible downwards, ensuring
        // num_warps/n_warp is multiple of 4
        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
            n_warp = n;
            m_warp = num_warps / n_warp;
            break;
          }
        }
      }
    } else if (this->policy == GemmWarpPolicy::kSquare) {
      // Exhaustive search, but m must be multiple of 4
      int max_m = this->M / kMPerWarp;
      int max_n = this->N / kNPerWarp;

      float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;

      float best_score = std::numeric_limits<float>::max();
      int best_m = kGroup, best_n = n_warp;

      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
        if (num_warps % m)
          continue;
        int n = num_warps / m;
        if (n > max_n)
          continue;

        float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
        float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
        float score = std::abs(m_per_warp / n_per_warp - ideal);

        if (score < best_score) {
          best_score = score;
          best_m = m;
          best_n = n;
        }
      }
      m_warp = best_m;
      n_warp = best_n;
135
136
137
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }
138
139
140

    ICHECK(m_warp * n_warp == num_warps)
        << "m_warp * n_warp must equal num_warps";
141
142
    return {m_warp, n_warp};
  }
143

144
  if (this->policy == GemmWarpPolicy::kFullRow) {
145
    // Try to partition M first
146
    m_warp = num_warps;
147
148
149
150
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
151
    if (this->M % (m_warp * kMPerWarp) != 0) {
152
      // Calculate how many warps we can use for M
153
      int max_m_warps = this->M / kMPerWarp;
154
155
156
157
158
159
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
160
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
161
162
    // Try to partition N first
    m_warp = 1;
163
    n_warp = num_warps;
164
165
166

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
167
    if (this->N % (n_warp * kNPerWarp) != 0) {
168
      // Calculate how many warps we can use for N
169
      int max_n_warps = this->N / kNPerWarp;
170
171
172
173
174
175
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
176
  } else if (this->policy == GemmWarpPolicy::kSquare) {
177
    // First calculate the maximum possible warps for each dimension
178
179
180
181
    int max_m_warps =
        this->M / kMPerWarp; // Each warp needs at least 16 elements in M
    int max_n_warps =
        this->N / kNPerWarp; // Each warp needs at least 8 elements in N
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (this->N > 0) {
      ideal_ratio = static_cast<float>(this->M) / this->N;
    }

    // Start with a balanced initial guess
    m_warp = 1;
    n_warp = 1;

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();

    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;

      // Calculate how balanced this partition is
203
204
      float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
      float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
205
206
207
208
209
210
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
211
212
      }
    }
213
214
215

    m_warp = best_m;
    n_warp = best_n;
216
217
218
219
220
221
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

222
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
223
224
225
226
  int warp_size = 32;
  if (TargetIsCDNA(T.target)) {
    warp_size = 64;
  }
227
  auto block_size = *as_const_int(T.thread_bounds->extent);
228
  bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
229
                     (block_size / warp_size % 4 == 0);
230

231
  auto [warp_m, warp_n] =
232
      ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
233

234
235
236
237
238
239
240
241
242
243
244
  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
245
  ss << ", " << clear_accum;
246
247
248
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
249
250
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (maybe_wgmma ? "true" : "false");
251
  }
252
253
254
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
255
256
257
258
259
260
261
  ss << ">";
  auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
  auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
  auto C_buffer = T.buffer_remap[C];

  Array<PrimExpr> new_args;
  new_args.push_back(StringImm(ss.str()));
262
263
264
  new_args.push_back(Aptr);
  new_args.push_back(Bptr);
  new_args.push_back(Cptr);
265
266
267
268
  auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
  return Evaluate(new_call);
}

269
270
271
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
272
273
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
274
275
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
276
277
  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
278
    auto [warp_m, warp_n] =
279
        ComputeWarpPartition(block_size / warp_size, T.target);
280
281
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
282
    results.Set(C, fragment->BindThreadRange(thread_range));
283
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
284
285
286
287
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
288
289
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
290
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
291
      results.Set(A, fragment->BindThreadRange(thread_range));
292
293
294
295
296
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
297
298
299
300
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
301
302
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
    const int warp_size = 32;
303
    auto [warp_m, warp_n] =
304
        ComputeWarpPartition(block_size / warp_size, T.target);
305
306
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
307
    results.Set(C, fragment->BindThreadRange(thread_range));
308
309

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
310
311
312
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
313
314
315
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
316
    } else if (A.scope() == "local.fragment") {
317
318
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
319
      results.Set(A, fragment->BindThreadRange(thread_range));
320
321
322
323
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
324
325
326
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
327
328
329
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
330
    } else if (B.scope() == "local.fragment") {
331
332
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
333
      results.Set(B, fragment->BindThreadRange(thread_range));
334
335
336
337
338
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    const int warp_size = 32;
339
    bool maybe_wgmma = (this->M >= 64) && (block_size / warp_size % 4 == 0);
340
    auto [warp_m, warp_n] =
341
        ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
342
    auto fragment =
343
344
345
346
        maybe_wgmma
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
347
    results.Set(C, fragment->BindThreadRange(thread_range));
348
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
349
350
351
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
352
      const int64_t continuity =
353
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
354
355
356
      results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous,
                                            mat_continuous, A->dtype.bits(),
                                            trans_A ? 1 : 2));
357
    } else {
358
359
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
360
      results.Set(A, fragment->BindThreadRange(thread_range));
361
362
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
363
364
365
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
366
367
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
368
369
370
      results.Set(B,
                  makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
                                         B->dtype.bits(), trans_B ? 2 : 1));
371
372
373
374
375
    } else {
      ICHECK(0) << "WGMMA only support B in shared.";
    }
  } else if (TargetIsCDNA(T.target)) {
    const int warp_size = 64;
376
    auto [warp_m, warp_n] =
377
        ComputeWarpPartition(block_size / warp_size, T.target);
378

379
380
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
381
    results.Set(C, fragment->BindThreadRange(thread_range));
382
383

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
384
385
386
387
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
388
389
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
390
391
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
392
      results.Set(A, fragment->BindThreadRange(thread_range));
393
394
395
396
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
397
398
399
400
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
401
402
403

      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
404
405
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
406
      results.Set(B, fragment->BindThreadRange(thread_range));
407
408
409
410
411
412
413
414
415
416
417
418
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register the `tl.gemm` TIR op. Marked kOpaque: it lowers to an external
// device-side call and must not be touched by generic TIR analyses.
// NOTE(review): the Gemm constructor reads 10-12 args — confirm
// set_num_inputs(5) is intentional here.
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm