gemm.cc 11.5 KB
Newer Older
1
2
3
4
5
6
7
8
/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

#include <algorithm>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

// Decompose `x` into its prime factors, listed in non-decreasing order with
// multiplicity (e.g. 12 -> {2, 2, 3}). Any x <= 1 yields an empty vector.
static std::vector<int> toPrimeFactors(int x) {
  std::vector<int> factors;
  for (int divisor = 2; x > 1; ++divisor) {
    // Strip every copy of `divisor` before moving on to the next candidate;
    // composite candidates never divide because their primes were removed.
    while (x % divisor == 0) {
      factors.push_back(divisor);
      x /= divisor;
    }
  }
  return factors;
}

/*!
 * \brief Build a Gemm op from the arguments of a gemm call.
 *
 * Argument layout, as consumed below:
 *   args[0..2]  access pointers for A, B, C (resolved to buffers via vmap)
 *   args[3..4]  trans_A, trans_B (Bool)
 *   args[5..7]  M, N, K (IntImm)
 *   args[8]     GemmWarpPolicy (IntImm)
 *   args[9]     clear_accum (Bool)
 *   args[10]    optional kPack (IntImm, must be 1 or 2)
 *   args[11]    optional wg_wait (IntImm)
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  // Fail with a readable message instead of an out-of-range index below.
  ICHECK(args.size() >= 10)
      << "gemm expects at least 10 arguments, but got " << args.size();

  Aptr = args[0];
  Bptr = args[1];
  Cptr = args[2];
  A = vmap[GetVarFromAccessPtr(Aptr)];
  B = vmap[GetVarFromAccessPtr(Bptr)];
  C = vmap[GetVarFromAccessPtr(Cptr)];

  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();

  if (args.size() > 10) {
    kPack = args[10].as<IntImm>().value()->value;
    ICHECK(kPack == 1 || kPack == 2) << "kPack must be 1 or 2";
  }
  if (args.size() > 11) {
    wg_wait = args[11].as<IntImm>().value()->value;
  }
}

61
62
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
                                               bool maybe_hopper_wgmma) const {
63
  int m_warp = 1, n_warp = 1;
64
65
66
  bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
                     (this->M >= 64) && (num_warps % 4 == 0);
  if (allow_wgmma) {
67
    ICHECK(num_warps % 4 == 0) << "Use Warp Group MMA requires 128*N threads.";
68
69
    if (this->policy == GemmWarpPolicy::kFullRow ||
        this->policy == GemmWarpPolicy::kSquare) {
70
      m_warp = num_warps;
71
      ICHECK(this->M % num_warps == 0) << this->M << " % " << num_warps;
72
73
74
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      m_warp = 4;
      n_warp = num_warps / 4;
75
      ICHECK(this->N % n_warp == 0) << this->N << " % " << n_warp;
76
77
78
79
80
81
82
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }
    return {m_warp, n_warp};
  }
  if (this->policy == GemmWarpPolicy::kFullRow) {
    m_warp = num_warps;
83
    ICHECK(this->M % num_warps == 0) << this->M << " % " << num_warps;
84
85
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    n_warp = num_warps;
86
    ICHECK(this->N % num_warps == 0) << this->N << " % " << num_warps;
87
88
89
90
91
92
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    auto factors = toPrimeFactors(num_warps);
    for (int factor : factors) {
      bool M_divisible = (this->M % (factor * m_warp)) == 0;
      bool N_divisible = (this->N % (factor * n_warp)) == 0;
      if (M_divisible && N_divisible) {
93
94
95
96
        // put N dimension first
        // because usually n in mma
        // is more smaller than m
        if (this->N / n_warp >= this->M / m_warp)
97
          n_warp *= factor;
98
99
        else
          m_warp *= factor;
100
101
      } else if (N_divisible) {
        n_warp *= factor;
102
103
      } else if (M_divisible) {
        m_warp *= factor;
104
105
106
107
108
109
110
111
112
113
114
      } else {
        ICHECK(0) << "Cannot compute warp partition for shape" << M << " " << N
                  << " with num_warps " << num_warps;
      }
    }
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

115
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
116
117
118
119
  int warp_size = 32;
  if (TargetIsCDNA(T.target)) {
    warp_size = 64;
  }
120
  auto block_size = *as_const_int(T.thread_bounds->extent);
121
  bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
122
                     (block_size / warp_size % 4 == 0);
123

124
  auto [warp_m, warp_n] =
125
      ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
126

127
128
129
130
131
132
133
134
135
136
137
  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
138
  ss << ", " << clear_accum;
139
140
141
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
142
143
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (maybe_wgmma ? "true" : "false");
144
  }
145
146
147
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
148
149
150
151
152
153
154
  ss << ">";
  auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
  auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
  auto C_buffer = T.buffer_remap[C];

  Array<PrimExpr> new_args;
  new_args.push_back(StringImm(ss.str()));
155
156
157
  new_args.push_back(Aptr);
  new_args.push_back(Bptr);
  new_args.push_back(Cptr);
158
159
160
161
  auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
  return Evaluate(new_call);
}

162
163
164
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
165
166
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
167
168
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
169
170
  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
171
    auto [warp_m, warp_n] =
172
        ComputeWarpPartition(block_size / warp_size, T.target);
173
174
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
175
    results.Set(C, fragment->BindThreadRange(thread_range));
176
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
177
178
179
180
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
                                           *as_const_int(A->shape[dim_A - 1]),
                                           true, trans_A ? 1 : 2));
181
182
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
183
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
184
      results.Set(A, fragment->BindThreadRange(thread_range));
185
186
187
188
189
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
190
191
192
193
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
                                         *as_const_int(B->shape[dim_B - 1]),
                                         false, trans_B ? 2 : 1));
194
195
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
    const int warp_size = 32;
196
    auto [warp_m, warp_n] =
197
        ComputeWarpPartition(block_size / warp_size, T.target);
198
199
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
200
    results.Set(C, fragment->BindThreadRange(thread_range));
201
202

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
203
204
205
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
206
207
208
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
209
    } else if (A.scope() == "local.fragment") {
210
211
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
212
      results.Set(A, fragment->BindThreadRange(thread_range));
213
214
215
216
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
217
218
219
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
220
221
222
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
223
    } else if (B.scope() == "local.fragment") {
224
225
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
226
      results.Set(B, fragment->BindThreadRange(thread_range));
227
228
229
230
231
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    const int warp_size = 32;
232
    bool maybe_wgmma = (this->M >= 64) && (block_size / warp_size % 4 == 0);
233
    auto [warp_m, warp_n] =
234
        ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
235
    auto fragment =
236
237
238
239
        maybe_wgmma
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
240
    results.Set(C, fragment->BindThreadRange(thread_range));
241
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
242
243
244
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
245
246
247
      const int64_t continuity =
          trans_A ? mat_continuous / (warp_m / 4) : mat_continuous;
      results.Set(A, makeGemmABLayout(mat_stride, mat_continuous, continuity,
248
249
250
                                      A->dtype.bits(), trans_A ? 1 : 2));
    } else {
      ICHECK(trans_A == false);
251
252
      auto fragment =
          makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
253
      results.Set(A, fragment->BindThreadRange(thread_range));
254
255
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
256
257
258
      int dim_B = B->shape.size();
      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
259
260
261
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, continuity,
262
263
264
265
266
267
                                      B->dtype.bits(), trans_B ? 2 : 1));
    } else {
      ICHECK(0) << "WGMMA only support B in shared.";
    }
  } else if (TargetIsCDNA(T.target)) {
    const int warp_size = 64;
268
    auto [warp_m, warp_n] =
269
        ComputeWarpPartition(block_size / warp_size, T.target);
270

271
272
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
273
    results.Set(C, fragment->BindThreadRange(thread_range));
274
275

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
276
277
278
279
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(A->shape[dim_A - 2]),
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
280
281
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
282
283
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
284
      results.Set(A, fragment->BindThreadRange(thread_range));
285
286
287
288
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
289
290
291
292
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(B->shape[dim_B - 2]),
          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
293
294
295

      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
296
297
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
298
      results.Set(B, fragment->BindThreadRange(thread_range));
299
300
301
302
303
304
305
306
307
308
309
310
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register the Gemm operator under the name `gemm` and tag its calls with
// CallEffectKind::kOpaque.
// NOTE(review): set_num_inputs(5) does not match the 10-12 call arguments that
// Gemm::Gemm reads; confirm whether this count is consulted anywhere.
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm