gemm.cc 11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

12
#include "builtin.h"
13
14
15
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
16
#include <tvm/tir/transform.h>
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

// Prime factorization of `x` by trial division.
// Returns the prime factors in non-decreasing order, repeated according to
// their multiplicity (e.g. 12 -> {2, 2, 3}). Inputs <= 1 yield an empty vector.
static std::vector<int> toPrimeFactors(int x) {
  std::vector<int> factors;
  for (int divisor = 2; x > 1;) {
    if (x % divisor != 0) {
      // This candidate no longer divides x; move on to the next one.
      ++divisor;
      continue;
    }
    factors.push_back(divisor);
    x /= divisor;
  }
  return factors;
}

/*!
 * \brief Build a Gemm op from its call arguments.
 *
 * Argument layout:
 *   args[0..2]  access pointers of A, B, C (resolved to buffers through vmap)
 *   args[3..4]  trans_A, trans_B (Bool immediates)
 *   args[5..7]  M, N, K (integer immediates)
 *   args[8]     GemmWarpPolicy (integer immediate)
 *   args[9]     optional kPack, must be 1 or 2
 *   args[10]    optional wg_wait
 * Optional members are left untouched when their argument is absent.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  // Buffer addressed by the access-pointer argument at position `idx`.
  auto buffer_at = [&](int idx) {
    return vmap[GetVarFromAccessPtr(args[idx])];
  };
  // Boolean immediate at position `idx`.
  auto bool_at = [&](int idx) { return args[idx].as<Bool>().value(); };
  // Integer immediate at position `idx`.
  auto int_at = [&](int idx) { return args[idx].as<IntImm>().value()->value; };

  A = buffer_at(0);
  B = buffer_at(1);
  C = buffer_at(2);
  trans_A = bool_at(3);
  trans_B = bool_at(4);
  M = int_at(5);
  N = int_at(6);
  K = int_at(7);
  policy = static_cast<GemmWarpPolicy>(int_at(8));

  const size_t num_args = args.size();
  if (num_args > 9) {
    kPack = int_at(9);
    if (kPack != 1 && kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (num_args > 10) {
    wg_wait = int_at(10);
  }
}

60
61
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
                                               bool maybe_hopper_wgmma) const {
62
  int m_warp = 1, n_warp = 1;
63
64
65
  bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
                     (this->M >= 64) && (num_warps % 4 == 0);
  if (allow_wgmma) {
66
    ICHECK(num_warps % 4 == 0) << "Use Warp Group MMA requires 128*N threads.";
67
68
    if (this->policy == GemmWarpPolicy::kFullRow ||
        this->policy == GemmWarpPolicy::kSquare) {
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
      m_warp = num_warps;
      ICHECK(this->M % num_warps == 0);
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      m_warp = 4;
      n_warp = num_warps / 4;
      ICHECK(this->N % n_warp == 0);
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }
    return {m_warp, n_warp};
  }
  if (this->policy == GemmWarpPolicy::kFullRow) {
    m_warp = num_warps;
    ICHECK(this->M % num_warps == 0);
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    n_warp = num_warps;
    ICHECK(this->N % num_warps == 0);
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    auto factors = toPrimeFactors(num_warps);
    for (int factor : factors) {
      bool M_divisible = (this->M % (factor * m_warp)) == 0;
      bool N_divisible = (this->N % (factor * n_warp)) == 0;
      if (M_divisible && N_divisible) {
        if (this->M / m_warp >= this->N / n_warp)
          m_warp *= factor;
        else
          n_warp *= factor;
      } else if (M_divisible) {
        m_warp *= factor;
      } else if (N_divisible) {
        n_warp *= factor;
      } else {
        ICHECK(0) << "Cannot compute warp partition for shape" << M << " " << N
                  << " with num_warps " << num_warps;
      }
    }
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  // TODO: perform more checks here
  return {m_warp, n_warp};
}

112
/*!
 * \brief Lower the Gemm op to a `call_extern` of a templated device function.
 *
 * The callee is `tl::gemm_ss`, `tl::gemm_rs` (A in `local.fragment`) or
 * `tl::gemm_sr` (B in `local.fragment`), instantiated as
 *   <M, N, K, warp_m, warp_n, trans_A, trans_B[, kPack | use_wgmma][, wg_wait]>
 * where kPack is appended on CDNA, the WGMMA flag on Hopper, and wg_wait only
 * when it is non-zero. Call arguments are the access pointers of A and B
 * (access mask 1) and C (access mask 3), taken from `T.buffer_remap`.
 * A and B fall back to the original buffers if not remapped; C must be remapped.
 */
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // CDNA wavefronts are 64 lanes wide; all other targets use 32.
  const int warp_size = TargetIsCDNA(T.target) ? 64 : 32;
  const auto num_warps = T.block_size / warp_size;

  // WGMMA needs Hopper, M >= 64 and a whole number of 4-warp groups.
  const bool maybe_wgmma =
      TargetIsHopper(T.target) && (this->M >= 64) && (num_warps % 4 == 0);
  auto [warp_m, warp_n] =
      ComputeWarpPartition(num_warps, T.target, maybe_wgmma);

  // Pick the kernel variant from which operand (if any) lives in registers.
  std::string op_name;
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  } else {
    op_name = "tl::gemm_ss";
  }

  std::ostringstream tmpl;
  tmpl << op_name << "<" << M << ", " << N << ", " << K << ", " << warp_m
       << ", " << warp_n << ", " << trans_A << ", " << trans_B;
  if (TargetIsCDNA(T.target)) {
    // CDNA kernels take the kPack factor as an extra template argument.
    tmpl << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    tmpl << ", " << (maybe_wgmma ? "true" : "false");
  }
  if (wg_wait != 0) {
    tmpl << ", " << wg_wait;
  }
  tmpl << ">";

  auto remap_or_self = [&T](const Buffer &buf) -> Buffer {
    return T.buffer_remap.count(buf) ? T.buffer_remap[buf] : buf;
  };
  const Buffer A_buffer = remap_or_self(A);
  const Buffer B_buffer = remap_or_self(B);
  const Buffer C_buffer = T.buffer_remap[C];

  Array<PrimExpr> call_args{StringImm(tmpl.str()), A_buffer.access_ptr(1),
                            B_buffer.access_ptr(1), C_buffer.access_ptr(3)};
  return Evaluate(Call(DataType::Handle(), builtin::call_extern(), call_args));
}

158
159
160
/*!
 * \brief Infer memory layouts for the GEMM operands.
 *
 * C must be a `local.fragment` buffer and receives a target-specific fragment
 * layout. A and B receive either a shared-memory layout (scope `shared` /
 * `shared.dyn`) or a register-fragment layout (scope `local.fragment`),
 * depending on the target and their scope. Unsupported combinations fail
 * with ICHECK. After the first call completes, later calls return an empty map.
 *
 * \param T     Layout-inference arguments (target, block size, ...).
 * \param level Inference level; not consulted by this operator.
 * \return Layouts keyed by buffer (C, A, B).
 */
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");

  // Constant extent of dimension `dim` of `buf`. The layout builders need
  // static shapes; report symbolic extents instead of dereferencing the null
  // pointer that as_const_int returns for them.
  auto const_dim = [](const Buffer &buf, int dim) -> int64_t {
    const int64_t *extent = as_const_int(buf->shape[dim]);
    ICHECK(extent != nullptr)
        << "Gemm requires constant buffer shapes, but dimension " << dim
        << " of buffer " << buf->name << " is " << buf->shape[dim];
    return *extent;
  };
  auto in_shared = [](const Buffer &buf) {
    return buf.scope() == "shared" || buf.scope() == "shared.dyn";
  };

  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);
    if (in_shared(A)) {
      results.Set(A, makeGemmVoltaABLayout(const_dim(A, 0), const_dim(A, 1),
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    results.Set(B, makeGemmVoltaABLayout(const_dim(B, 0), const_dim(B, 1),
                                         false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);

    if (in_shared(A)) {
      const int64_t mat_stride = const_dim(A, 0);
      const int64_t mat_continuous = const_dim(A, 1);
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                       A->dtype.bits()));
    } else {
      ICHECK(0);
    }
    if (in_shared(B)) {
      const int64_t mat_stride = const_dim(B, 0);
      const int64_t mat_continuous = const_dim(B, 1);
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      ICHECK(trans_B == false);
      results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    const int warp_size = 32;
    bool maybe_wgmma = (this->M >= 64) && (T.block_size / warp_size % 4 == 0);
    if (!maybe_wgmma) {
      LOG(WARNING)
          << "WGMMA is not enabled because M < 64 or block_size % 128 != 0";
    }
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target, maybe_wgmma);
    auto fragment =
        maybe_wgmma
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);
    if (in_shared(A)) {
      const int64_t mat_stride = const_dim(A, 0);
      const int64_t mat_continuous = const_dim(A, 1);
      // For transposed A the contiguous extent is divided by the number of
      // 4-warp groups along M. With fewer than 4 warps along M (reachable
      // when WGMMA is disabled) there is a single group; `warp_m / 4` would
      // be 0 there and the division would trap.
      const int64_t warp_groups_m = warp_m >= 4 ? warp_m / 4 : 1;
      const int64_t continuity =
          trans_A ? mat_continuous / warp_groups_m : mat_continuous;
      results.Set(A, makeGemmABLayout(mat_stride, mat_continuous, continuity,
                                      A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                       A->dtype.bits()));
    } else {
      // Consistent with the other targets: only shared or fragment A.
      ICHECK(0) << "Unsupported scope for Gemm operand A: " << A.scope();
    }
    if (in_shared(B)) {
      const int64_t mat_stride = const_dim(B, 0);
      const int64_t mat_continuous = const_dim(B, 1);
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, continuity,
                                      B->dtype.bits(), trans_B ? 2 : 1));
    } else {
      ICHECK(0) << "WGMMA only support B in shared.";
    }
  } else if (TargetIsCDNA(T.target)) {
    ICHECK(trans_B == true) << "Currently only support Transpose B for CDNA";

    const int warp_size = 64;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);

    // Shared operands use the swizzle/pad layout from makeGemmABLayoutCDNA.
    // A plain linear layout (makeGemmLayoutLinear(rows, cols)) is the
    // alternative that was previously left commented out here.
    if (in_shared(A)) {
      results.Set(A, makeGemmABLayoutCDNA(const_dim(A, 0), const_dim(A, 1),
                                          A->dtype.bits(), kPack));
    } else if (A.scope() == "local.fragment") {
      results.Set(
          A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n, trans_A));
    } else {
      ICHECK(0);
    }
    if (in_shared(B)) {
      results.Set(B, makeGemmABLayoutCDNA(const_dim(B, 0), const_dim(B, 1),
                                          B->dtype.bits(), kPack));
    } else if (B.scope() == "local.fragment") {
      results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register the Gemm tile operator (op name derived from `gemm` by
// TIR_REGISTER_TL_OP) with an opaque call-effect kind, so no assumptions are
// made about its side effects.
// NOTE(review): set_num_inputs(5) does not match the 9-11 arguments the Gemm
// constructor reads; confirm whether this count is enforced anywhere.
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>(
        "TCallEffectKind", Integer(CallEffectKind::kOpaque));
314

315
316
} // namespace tl
} // namespace tvm