/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

#include "builtin.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

// Decompose x into its prime factors, in non-decreasing order.
// For example, toPrimeFactors(12) returns {2, 2, 3}.
static std::vector<int> toPrimeFactors(int x) {
  int i = 2;
  std::vector<int> result;
  while (x > 1) {
    if (x % i == 0) {
      x /= i;
      result.push_back(i);
    } else {
      i++;
    }
  }
  return result;
}

// Construct a Gemm op from the packed tl.gemm call arguments:
// args[0..2] are access pointers to A, B and C, args[3..4] are the transpose
// flags, args[5..7] are the M/N/K extents, args[8] is the warp policy,
// args[9] is clear_accum, and the optional args[10..11] carry kPack and
// wg_wait.
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  A = vmap[GetVarFromAccessPtr(args[0])];
  B = vmap[GetVarFromAccessPtr(args[1])];
  C = vmap[GetVarFromAccessPtr(args[2])];
  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();
  if (args.size() > 10) {
    kPack = args[10].as<IntImm>().value()->value;
    if (kPack != 1 && kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 11) {
    wg_wait = args[11].as<IntImm>().value()->value;
  }
}

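// Split `num_warps` warps into an (m_warp, n_warp) grid according to the
// warp policy. On Hopper, when warp-group MMA (WGMMA) can be used, warps are
// assigned in whole warp groups (multiples of 4). Illustrative example (an
// assumption for exposition, not taken from a test): with the kSquare policy,
// num_warps = 8, M = N = 128 and the non-WGMMA path, the prime factors
// {2, 2, 2} are distributed greedily and yield m_warp = 2, n_warp = 4.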
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
                                               bool maybe_hopper_wgmma) const {
  int m_warp = 1, n_warp = 1;
  bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
                     (this->M >= 64) && (num_warps % 4 == 0);
  if (allow_wgmma) {
    ICHECK(num_warps % 4 == 0)
        << "Warp-group MMA requires the number of warps to be a multiple of 4 "
           "(i.e. 128*N threads).";
    if (this->policy == GemmWarpPolicy::kFullRow ||
        this->policy == GemmWarpPolicy::kSquare) {
      m_warp = num_warps;
      ICHECK(this->M % num_warps == 0);
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      m_warp = 4;
      n_warp = num_warps / 4;
      ICHECK(this->N % n_warp == 0);
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }
    return {m_warp, n_warp};
  }
  if (this->policy == GemmWarpPolicy::kFullRow) {
    m_warp = num_warps;
    ICHECK(this->M % num_warps == 0);
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    n_warp = num_warps;
    ICHECK(this->N % num_warps == 0);
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    auto factors = toPrimeFactors(num_warps);
    for (int factor : factors) {
      bool M_divisible = (this->M % (factor * m_warp)) == 0;
      bool N_divisible = (this->N % (factor * n_warp)) == 0;
      if (M_divisible && N_divisible) {
        // Grow the N dimension first, because the per-warp N tile of an MMA
        // is usually smaller than the per-warp M tile.
        if (this->N / n_warp >= this->M / m_warp)
          n_warp *= factor;
        else
          m_warp *= factor;
      } else if (N_divisible) {
        n_warp *= factor;
      } else if (M_divisible) {
        m_warp *= factor;
      } else {
        ICHECK(0) << "Cannot compute warp partition for shape " << M << " "
                  << N << " with num_warps " << num_warps;
      }
    }
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

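// Lower the gemm into a call_extern to one of the templated device kernels
// (tl::gemm_ss / tl::gemm_rs / tl::gemm_sr, chosen by the memory scopes of A
// and B). As an illustrative sketch only (the exact parameter list depends on
// the target-specific branches below), an Ampere fp16 GEMM with M = N = 128,
// K = 32, a 2x2 warp partition, no transposes and no accumulator clearing
// would produce a call of the form
//   tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0>(A_ptr, B_ptr, C_ptr);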
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  int warp_size = 32;
  if (TargetIsCDNA(T.target)) {
    warp_size = 64;
  }
  auto block_size = *as_const_int(T.thread_bounds->extent);
  bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
                     (block_size / warp_size % 4 == 0);

  auto [warp_m, warp_n] =
      ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);

  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  if (TargetIsCDNA(T.target)) {
    // For CDNA (AMD) GEMM, the kernel additionally needs the kPack parameter.
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (maybe_wgmma ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";
  auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
  auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
  auto C_buffer = T.buffer_remap[C];

  Array<PrimExpr> new_args;
  new_args.push_back(StringImm(ss.str()));
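  // access_ptr flag values follow TVM's convention: 1 = read, 2 = write,
  // 3 = read | write. A and B are read-only inputs; C is the accumulator.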
  new_args.push_back(A_buffer.access_ptr(1));
  new_args.push_back(B_buffer.access_ptr(1));
  new_args.push_back(C_buffer.access_ptr(3));
  auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
  return Evaluate(new_call);
}

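// Infer memory layouts for A, B and C. The accumulator C always lives in
// register fragments; A and B may sit in shared memory (swizzled or padded
// layouts) or in fragments, and the concrete layout is chosen per target
// architecture (Volta, Turing/Ampere, Hopper, CDNA) below.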
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  auto block_size = *as_const_int(T.thread_bounds->extent);
  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target);
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[0]),
                                           *as_const_int(A->shape[1]), true,
                                           trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[0]),
                                         *as_const_int(B->shape[1]), false,
                                         trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      const int64_t mat_stride = *as_const_int(A->shape[0]);
      const int64_t mat_continuous = *as_const_int(A->shape[1]);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                       A->dtype.bits()));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      const int64_t mat_stride = *as_const_int(B->shape[0]);
      const int64_t mat_continuous = *as_const_int(B->shape[1]);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      ICHECK(trans_B == false) << "When B is in local.fragment, trans_B must "
                                  "be false; please raise an issue if you hit "
                                  "this.";
      results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    const int warp_size = 32;
    bool maybe_wgmma = (this->M >= 64) && (block_size / warp_size % 4 == 0);
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
    auto fragment =
        maybe_wgmma
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      const int64_t mat_stride = *as_const_int(A->shape[0]);
      const int64_t mat_continuous = *as_const_int(A->shape[1]);
      const int64_t continuity =
          trans_A ? mat_continuous / (warp_m / 4) : mat_continuous;
      results.Set(A, makeGemmABLayout(mat_stride, mat_continuous, continuity,
                                      A->dtype.bits(), trans_A ? 1 : 2));
    } else {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                       A->dtype.bits()));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      const int64_t mat_stride = *as_const_int(B->shape[0]);
      const int64_t mat_continuous = *as_const_int(B->shape[1]);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, continuity,
                                      B->dtype.bits(), trans_B ? 2 : 1));
    } else {
      ICHECK(0) << "WGMMA only supports B in shared memory.";
    }
  } else if (TargetIsCDNA(T.target)) {
    const int warp_size = 64;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target);

    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());

    results.Set(C, fragment);

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      // A linear memory access layout would be:
      //   auto shared_layout = makeGemmLayoutLinear(
      //       *as_const_int(A->shape[0]), *as_const_int(A->shape[1]));
      // Instead, make a swizzled or padded layout.
      auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(A->shape[0]),
                                                *as_const_int(A->shape[1]),
                                                A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      results.Set(A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                           A->dtype.bits(), trans_A));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      // A linear memory access layout would be:
      //   auto shared_layout = makeGemmLayoutLinear(
      //       *as_const_int(B->shape[0]), *as_const_int(B->shape[1]));
      // Instead, make a swizzled or padded layout.
      auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(B->shape[0]),
                                                *as_const_int(B->shape[1]),
                                                B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register the tl.gemm operator; its calls are opaque to TIR (treated as
// having unknown side effects).
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm