/*!
 * \file tl/op/gemm.cc
 *
 * \brief Define the gemm operator.
 */

#include "gemm.h"

#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

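// Decompose a positive integer into its prime factors in non-decreasing
// order, e.g. toPrimeFactors(12) returns {2, 2, 3}. Used below to split the
// warp count between the M and N dimensions.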
static std::vector<int> toPrimeFactors(int x) {
  int i = 2;
  std::vector<int> result;
  while (x > 1) {
    if (x % i == 0) {
      x /= i;
      result.push_back(i);
    } else {
      i++;
    }
  }
  return result;
}

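// Construct a Gemm op from the arguments of the tl.gemm call:
//   args[0..2]  access pointers of A, B and C
//   args[3..4]  trans_A, trans_B
//   args[5..7]  M, N, K
//   args[8]     warp partition policy
//   args[9]     clear_accum
//   args[10]    (optional) kPack for CDNA, must be 1 or 2
//   args[11]    (optional) wg_wait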
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  A = vmap[GetVarFromAccessPtr(args[0])];
  B = vmap[GetVarFromAccessPtr(args[1])];
  C = vmap[GetVarFromAccessPtr(args[2])];
  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();
  if (args.size() > 10) {
    kPack = args[10].as<IntImm>().value()->value;
    ICHECK(kPack == 1 || kPack == 2) << "kPack must be 1 or 2";
  }
  if (args.size() > 11) {
    wg_wait = args[11].as<IntImm>().value()->value;
  }
}

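// Split num_warps into an (m_warp, n_warp) grid according to the warp policy;
// a separate path applies when Hopper warp-group MMA (WGMMA) is usable. For
// the kSquare policy the prime factors of num_warps are greedily assigned to
// whichever dimension currently has more work per warp; e.g. num_warps = 8
// with M = N = 128 yields m_warp = 4, n_warp = 2 on the non-WGMMA path.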
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
                                               bool maybe_hopper_wgmma) const {
  int m_warp = 1, n_warp = 1;
  bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
                     (this->M >= 64) && (num_warps % 4 == 0);
  if (allow_wgmma) {
    ICHECK(num_warps % 4 == 0) << "Warp Group MMA requires a multiple of 128 threads per block.";
    if (this->policy == GemmWarpPolicy::kFullRow ||
        this->policy == GemmWarpPolicy::kSquare) {
      m_warp = num_warps;
      ICHECK(this->M % num_warps == 0);
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      m_warp = 4;
      n_warp = num_warps / 4;
      ICHECK(this->N % n_warp == 0);
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }
    return {m_warp, n_warp};
  }
  if (this->policy == GemmWarpPolicy::kFullRow) {
    m_warp = num_warps;
    ICHECK(this->M % num_warps == 0);
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    n_warp = num_warps;
    ICHECK(this->N % num_warps == 0);
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    auto factors = toPrimeFactors(num_warps);
    for (int factor : factors) {
      bool M_divisible = (this->M % (factor * m_warp)) == 0;
      bool N_divisible = (this->N % (factor * n_warp)) == 0;
      if (M_divisible && N_divisible) {
        if (this->M / m_warp >= this->N / n_warp)
          m_warp *= factor;
        else
          n_warp *= factor;
      } else if (M_divisible) {
        m_warp *= factor;
      } else if (N_divisible) {
        n_warp *= factor;
      } else {
        ICHECK(0) << "Cannot compute warp partition for shape " << M << " " << N
                  << " with num_warps " << num_warps;
      }
    }
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  // TODO: perform more checks here
  return {m_warp, n_warp};
}

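// Lower the gemm op to a single extern call into the device-side template,
// e.g. tl::gemm_ss<M, N, K, m_warp, n_warp, trans_A, trans_B, ...> applied to
// the A, B and C pointers.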
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
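  // A warp is 32 threads on NVIDIA targets; CDNA wavefronts are 64 threads.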
  int warp_size = 32;
  if (TargetIsCDNA(T.target)) {
    warp_size = 64;
  }

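  // WGMMA is only considered on Hopper when the tile has at least 64 rows and
  // the block contains a whole number of warp groups (a multiple of 4 warps).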
  bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
                     (T.block_size / warp_size % 4 == 0);

  auto [warp_m, warp_n] =
      ComputeWarpPartition(T.block_size / warp_size, T.target, maybe_wgmma);

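  // Pick the device kernel by operand location: the suffix encodes where A
  // and B live (s = shared memory, r = register fragment), e.g. gemm_rs takes
  // A from registers and B from shared memory.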
  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  if (TargetIsCDNA(T.target)) {
    // For CDNA gemm, we also need to pass the kPack factor to the kernel.
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (maybe_wgmma ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";
  auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
  auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
  auto C_buffer = T.buffer_remap[C];

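  // Build the extern call: the kernel name string followed by access pointers
  // (TVM access masks: 1 = read for A and B, 3 = read | write for C).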
  Array<PrimExpr> new_args;
  new_args.push_back(StringImm(ss.str()));
  new_args.push_back(A_buffer.access_ptr(1));
  new_args.push_back(B_buffer.access_ptr(1));
  new_args.push_back(C_buffer.access_ptr(3));
  auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
  return Evaluate(new_call);
}

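// Infer memory layouts for A, B and C on the given target. C always lives in
// a register fragment; A and B receive either swizzled/padded shared-memory
// layouts or fragment layouts, depending on their scope.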
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");

  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[0]),
                                           *as_const_int(A->shape[1]), true,
                                           trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[0]),
                                         *as_const_int(B->shape[1]), false,
                                         trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      const int64_t mat_stride = *as_const_int(A->shape[0]);
      const int64_t mat_continuous = *as_const_int(A->shape[1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                       A->dtype.bits()));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      const int64_t mat_stride = *as_const_int(B->shape[0]);
      const int64_t mat_continuous = *as_const_int(B->shape[1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      ICHECK(trans_B == false);
      results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    const int warp_size = 32;
    bool maybe_wgmma = (this->M >= 64) && (T.block_size / warp_size % 4 == 0);
    if (!maybe_wgmma) {
      LOG(WARNING)
          << "WGMMA is not enabled because M < 64 or block_size % 128 != 0";
    }
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target, maybe_wgmma);
    auto fragment =
        maybe_wgmma
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      const int64_t mat_stride = *as_const_int(A->shape[0]);
      const int64_t mat_continuous = *as_const_int(A->shape[1]);
      const int64_t continuity =
          trans_A ? mat_continuous / (warp_m / 4) : mat_continuous;
      results.Set(A, makeGemmABLayout(mat_stride, mat_continuous, continuity,
                                      A->dtype.bits(), trans_A ? 1 : 2));
    } else {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                       A->dtype.bits()));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      const int64_t mat_stride = *as_const_int(B->shape[0]);
      const int64_t mat_continuous = *as_const_int(B->shape[1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, continuity,
                                      B->dtype.bits(), trans_B ? 2 : 1));
    } else {
      ICHECK(0) << "WGMMA only supports B in shared memory.";
    }
  } else if (TargetIsCDNA(T.target)) {
    ICHECK(trans_B == true) << "Currently only transposed B is supported for CDNA";

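    // CDNA wavefronts are 64 threads wide.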
    const int warp_size = 64;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);

    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());

    results.Set(C, fragment);

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {

      // Make Linear Memory Access Layout
      // auto shared_layout =
      //     makeGemmLayoutLinear(*as_const_int(A->shape[0]),
      //     *as_const_int(A->shape[1]));

      // Make Swizzle or Pad Layout
      auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(A->shape[0]),
                                                *as_const_int(A->shape[1]),
                                                A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      results.Set(
          A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n, trans_A));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      // Make Linear Memory Access Layout
      // auto shared_layout =
      //     makeGemmLayoutLinear(*as_const_int(B->shape[0]),
      //     *as_const_int(B->shape[1]));

      // Make Swizzle or Pad Layout
      auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(B->shape[0]),
                                                *as_const_int(B->shape[1]),
                                                B->dtype.bits(), kPack);

      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Target not supported: " << T.target->str();
  }
  completed_ = true;
  return results;
}

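// Register tl.gemm; calls to it are treated as opaque by TIR effect analysis.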
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm