gemm.cc 12.6 KB
Newer Older
1
2
3
4
5
6
7
8
/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

#include "builtin.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

// Decompose a positive integer into its prime factors, in ascending order.
// Returns an empty vector for x <= 1.
static std::vector<int> toPrimeFactors(int x) {
  std::vector<int> factors;
  for (int d = 2; x > 1;) {
    if (x % d != 0) {
      ++d; // d does not divide x; try the next candidate
      continue;
    }
    x /= d;
    factors.push_back(d);
  }
  return factors;
}

/*!
 * \brief Construct a Gemm op from the TL call arguments.
 *
 * Argument layout (positional):
 *   0-2: access pointers for A, B, C;
 *   3-4: transpose flags for A and B;
 *   5-7: problem sizes M, N, K;
 *   8:   warp partition policy;
 *   9:   whether the accumulator is cleared before the GEMM;
 *   10:  (optional) kPack for CDNA targets — must be 1 or 2;
 *   11:  (optional) wg_wait for Hopper warp-group synchronization.
 *
 * \param args Positional arguments as described above.
 * \param vmap Maps the access-pointer variables to their backing buffers.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  Aptr = args[0];
  Bptr = args[1];
  Cptr = args[2];
  // Resolve the operand buffers through the access-pointer variables.
  A = vmap[GetVarFromAccessPtr(Aptr)];
  B = vmap[GetVarFromAccessPtr(Bptr)];
  C = vmap[GetVarFromAccessPtr(Cptr)];
  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();
  if (args.size() > 10) {
    kPack = args[10].as<IntImm>().value()->value;
    ICHECK(kPack == 1 || kPack == 2) << "kPack must be 1 or 2";
  }
  if (args.size() > 11) {
    wg_wait = args[11].as<IntImm>().value()->value;
  }
}

61
62
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
                                               bool maybe_hopper_wgmma) const {
63
  int m_warp = 1, n_warp = 1;
64
65
66
  bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
                     (this->M >= 64) && (num_warps % 4 == 0);
  if (allow_wgmma) {
67
    ICHECK(num_warps % 4 == 0) << "Use Warp Group MMA requires 128*N threads.";
68
69
    if (this->policy == GemmWarpPolicy::kFullRow ||
        this->policy == GemmWarpPolicy::kSquare) {
70
      m_warp = num_warps;
71
      n_warp = 1;
72
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
73
74
      m_warp = 1;
      n_warp = num_warps;
75
76
77
78
79
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }
    return {m_warp, n_warp};
  }
80

81
  if (this->policy == GemmWarpPolicy::kFullRow) {
82
    // Try to partition M first
83
    m_warp = num_warps;
84
85
86
87
88
89
90
91
92
93
94
95
96
    n_warp = 1;

    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (this->M % (m_warp * 16) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = this->M / 16;
      m_warp = max_m_warps;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
97
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
98
99
    // Try to partition N first
    m_warp = 1;
100
    n_warp = num_warps;
101
102
103
104
105
106
107
108
109
110
111
112

    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (this->N % (n_warp * 8) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = this->N / 8;
      n_warp = max_n_warps;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
113
  } else if (this->policy == GemmWarpPolicy::kSquare) {
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
    // First calculate the maximum possible warps for each dimension
    int max_m_warps = this->M / 16; // Each warp needs at least 16 elements in M
    int max_n_warps = this->N / 8;  // Each warp needs at least 8 elements in N

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (this->N > 0) {
      ideal_ratio = static_cast<float>(this->M) / this->N;
    }

    // Start with a balanced initial guess
    m_warp = 1;
    n_warp = 1;

    // Try to find the best balanced partition
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();

    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;
      if (n > max_n_warps)
        continue;
      if (m * n != num_warps)
        continue;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(this->M) / (m * 16);
      float n_per_warp = static_cast<float>(this->N) / (n * 8);
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
150
151
      }
    }
152
153
154

    m_warp = best_m;
    n_warp = best_n;
155
156
157
158
159
160
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

161
/*!
 * \brief Lower the GEMM op to an extern call into the TL C++ template library.
 *
 * Builds the templated kernel name (tl::gemm_ss / gemm_rs / gemm_sr depending
 * on which operands live in fragments) with the problem size, warp partition
 * and flags baked in as template arguments, and emits a call_extern.
 *
 * \param T Lowering context (target, thread bounds, buffer remaps).
 * \param analyzer Arithmetic analyzer (unused here; part of the interface).
 * \return An Evaluate statement wrapping the extern call.
 */
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // CDNA (AMD) wavefronts are 64 lanes; NVIDIA warps are 32.
  int warp_size = TargetIsCDNA(T.target) ? 64 : 32;
  auto block_size = *as_const_int(T.thread_bounds->extent);
  // WGMMA requires Hopper, M >= 64 and a whole number of warp groups.
  bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
                     (block_size / warp_size % 4 == 0);

  auto [warp_m, warp_n] =
      ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);

  // Select the kernel flavor from operand scopes:
  //   ss = both operands in shared, rs = A in fragment, sr = B in fragment.
  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (maybe_wgmma ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";

  // C must have been assigned a layout remap by InferLayout; A and B may
  // legitimately be unremapped (e.g. fragments laid out by the partition).
  ICHECK(T.buffer_remap.count(C)) << "C buffer has no layout remap";

  Array<PrimExpr> new_args;
  new_args.push_back(StringImm(ss.str()));
  new_args.push_back(Aptr);
  new_args.push_back(Bptr);
  new_args.push_back(Cptr);
  auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
  return Evaluate(new_call);
}

208
209
210
/*!
 * \brief Infer memory layouts for the A, B and C buffers of this GEMM.
 *
 * Dispatches on the target architecture (Volta / Ampere-Turing / Hopper /
 * CDNA) and assigns either a shared-memory swizzled layout or a per-thread
 * fragment layout to each operand. Runs once; subsequent calls return empty.
 *
 * \param T Layout-inference context (target, thread bounds).
 * \param level Inference level (unused here; part of the interface).
 * \return Map from buffer to its inferred layout.
 */
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target);
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      // Volta fragment-A path only supports non-transposed A.
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    // Volta B operand must live in shared memory.
    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    const int warp_size = 32;
    bool maybe_wgmma = (this->M >= 64) && (block_size / warp_size % 4 == 0);
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
    auto fragment =
        maybe_wgmma
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      // BUGFIX: pass the computed `continuity` (was dead — the call passed
      // mat_continuous, ignoring the warp-partition adjustment), mirroring
      // the B-operand branch below.
      results.Set(A, makeGemmABLayout(mat_stride, mat_continuous, continuity,
                                      A->dtype.bits(), trans_A ? 1 : 2));
    } else {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, continuity,
                                      B->dtype.bits(), trans_B ? 2 : 1));
    } else {
      ICHECK(0) << "WGMMA only support B in shared.";
    }
  } else if (TargetIsCDNA(T.target)) {
    // CDNA wavefronts are 64 lanes wide.
    const int warp_size = 64;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target);

    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register the gemm op with 5 inputs; kOpaque because the extern GEMM call
// has side effects the TIR analyzer cannot see through.
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm