gemm.cc 9.65 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Decompose an integer into its prime factors.
 *
 * Factors are returned in non-decreasing order with multiplicity, e.g.
 * toPrimeFactors(12) == {2, 2, 3}. Inputs <= 1 yield an empty vector.
 *
 * Trial division only has to continue while i * i <= x: once no divisor up
 * to sqrt(x) is left, any remaining x > 1 is itself prime. The bound is
 * written as `i <= x / i` so it cannot overflow for large x.
 */
static std::vector<int> toPrimeFactors(int x) {
  std::vector<int> result;
  for (int i = 2; x > 1 && i <= x / i; ++i) {
    while (x % i == 0) {
      result.push_back(i);
      x /= i;
    }
  }
  // Whatever survives trial division up to sqrt(x) is a single prime.
  if (x > 1) {
    result.push_back(x);
  }
  return result;
}

/*!
 * \brief Build a Gemm operator from the arguments of a `tl.gemm` call.
 *
 * Argument layout, as read below:
 *   args[0..2] : access pointers of A, B and C, resolved to buffers via \p vmap
 *   args[3..4] : trans_A, trans_B (boolean immediates)
 *   args[5..7] : M, N, K (integer immediates)
 *   args[8]    : GemmWarpPolicy, encoded as an integer immediate
 *   args[9]    : optional kPack, which must be 1 or 2
 *
 * \param args Call arguments of the `tl.gemm` intrinsic.
 * \param vmap Mapping from buffer data variables to buffers.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  ICHECK(args.size() >= 9) << "tl.gemm expects at least 9 arguments, got "
                           << args.size();
  A = vmap[GetVarFromAccessPtr(args[0])];
  B = vmap[GetVarFromAccessPtr(args[1])];
  C = vmap[GetVarFromAccessPtr(args[2])];
  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  if (args.size() > 9) {
    kPack = args[9].as<IntImm>().value()->value;
    ICHECK(kPack == 1 || kPack == 2) << "kPack must be 1 or 2";
  }
}

55
56
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps,
                                               Target target) const {
57
58
59
  int m_warp = 1, n_warp = 1;
  if (TargetIsHopper(target)) {
    ICHECK(num_warps % 4 == 0) << "Use Warp Group MMA requires 128*N threads.";
60
61
    if (this->policy == GemmWarpPolicy::kFullRow ||
        this->policy == GemmWarpPolicy::kSquare) {
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
      m_warp = num_warps;
      ICHECK(this->M % num_warps == 0);
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      m_warp = 4;
      n_warp = num_warps / 4;
      ICHECK(this->N % n_warp == 0);
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }
    return {m_warp, n_warp};
  }
  if (this->policy == GemmWarpPolicy::kFullRow) {
    m_warp = num_warps;
    ICHECK(this->M % num_warps == 0);
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    n_warp = num_warps;
    ICHECK(this->N % num_warps == 0);
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    auto factors = toPrimeFactors(num_warps);
    for (int factor : factors) {
      bool M_divisible = (this->M % (factor * m_warp)) == 0;
      bool N_divisible = (this->N % (factor * n_warp)) == 0;
      if (M_divisible && N_divisible) {
        if (this->M / m_warp >= this->N / n_warp)
          m_warp *= factor;
        else
          n_warp *= factor;
      } else if (M_divisible) {
        m_warp *= factor;
      } else if (N_divisible) {
        n_warp *= factor;
      } else {
        ICHECK(0) << "Cannot compute warp partition for shape" << M << " " << N
                  << " with num_warps " << num_warps;
      }
    }
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  // TODO: perform more checks here
  return {m_warp, n_warp};
}

105
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
106
107
108
109
110
111
  int warp_size = 32;
  if (TargetIsCDNA(T.target)) {
    warp_size = 64;
  }

  ICHECK(T.block_size % warp_size == 0);
112
113
  auto [warp_m, warp_n] =
      ComputeWarpPartition(T.block_size / warp_size, T.target);
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  }
  ss << ">";
  auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
  auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
  auto C_buffer = T.buffer_remap[C];

  Array<PrimExpr> new_args;
  new_args.push_back(StringImm(ss.str()));
  new_args.push_back(A_buffer.access_ptr(1));
  new_args.push_back(B_buffer.access_ptr(1));
  new_args.push_back(C_buffer.access_ptr(3));
  auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
  return Evaluate(new_call);
}

143
144
145
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
146
147
148
149
150
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");

  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
151
152
153
154
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
155
156
    results.Set(C, fragment);
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
157
158
159
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[0]),
                                           *as_const_int(A->shape[1]), true,
                                           trans_A ? 1 : 2));
160
161
162
163
164
165
166
167
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
168
169
170
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[0]),
                                         *as_const_int(B->shape[1]), false,
                                         trans_B ? 2 : 1));
171
172
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
    const int warp_size = 32;
173
174
175
176
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
177
178
179
    results.Set(C, fragment);

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
180
181
      results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]),
                                      *as_const_int(A->shape[1]),
182
183
184
                                      A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
185
186
      results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                       A->dtype.bits()));
187
188
189
190
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
191
192
      results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]),
                                      *as_const_int(B->shape[1]),
193
194
195
196
197
198
199
200
201
                                      B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      ICHECK(trans_B == false);
      results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    const int warp_size = 32;
202
203
204
205
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n, C->dtype.bits());
206
207
    results.Set(C, fragment);
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
208
209
      results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]),
                                      *as_const_int(A->shape[1]),
210
211
212
                                      A->dtype.bits(), trans_A ? 1 : 2));
    } else {
      ICHECK(trans_A == false);
213
214
      results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                       A->dtype.bits()));
215
216
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
217
218
      results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]),
                                      *as_const_int(B->shape[1]),
219
220
221
222
223
224
225
226
                                      B->dtype.bits(), trans_B ? 2 : 1));
    } else {
      ICHECK(0) << "WGMMA only support B in shared.";
    }
  } else if (TargetIsCDNA(T.target)) {
    ICHECK(trans_B == true) << "Currently only support Transpose B for CDNA";

    const int warp_size = 64;
227
228
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
229

230
231
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
232
233
234
235

    results.Set(C, fragment);

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
236

237
238
      // Make Linear Memory Access Layout
      // auto shared_layout =
239
240
      //     makeGemmLayoutLinear(*as_const_int(A->shape[0]),
      //     *as_const_int(A->shape[1]));
241
242

      // Make Swizzle or Pad Layout
243
244
245
      auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(A->shape[0]),
                                                *as_const_int(A->shape[1]),
                                                A->dtype.bits(), kPack);
246
247
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
248
249
      results.Set(
          A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n, trans_A));
250
251
252
253
254
255
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      // Make Linear Memory Access Layout
      // auto shared_layout =
256
257
      //     makeGemmLayoutLinear(*as_const_int(B->shape[0]),
      //     *as_const_int(B->shape[1]));
258
259

      // Make Swizzle or Pad Layout
260
261
262
      auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(B->shape[0]),
                                                *as_const_int(B->shape[1]),
                                                B->dtype.bits(), kPack);
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278

      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register the `tl.gemm` op. Gemm::Gemm reads 9 arguments plus an optional
// 10th (kPack), so the arity is declared variable (-1) instead of a fixed
// count. The call is marked opaque so no assumptions are made about its
// effects.
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm