gemm_py.cc 13.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
/*!
 * \file tl/op/gemm_py.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm_py.h"

#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"
15
#include "region.h"
16
#include "tcgen5_meta.h"
17
#include "utils.h"
18
19
20
21
22
23

namespace tvm {
namespace tl {

using namespace tir;

24
// NormalizeToBufferRegion moved to src/op/utils.{h,cc}
25

26
// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/**
 * @brief Construct a GemmPy operator from serialized TL arguments and a buffer
 * map.
 *
 * Deserializes the operator parameters from `args`, resolves buffers through
 * `vmap`, and stores the populated GemmPyNode into the wrapper's `data_`.
 *
 * @param args Positional arguments; all 19 entries are required because the
 *   trailing entries (16..18) are read unconditionally:
 *     [0]  A operand, normalized with NormalizeToBufferRegion
 *     [1]  B operand, normalized with NormalizeToBufferRegion
 *     [2]  C operand, normalized with NormalizeToBufferRegion
 *     [3]  trans_A (Bool)
 *     [4]  trans_B (Bool)
 *     [5]  M (IntImm)
 *     [6]  N (IntImm)
 *     [7]  K (IntImm)
 *     [8]  warp policy (IntImm, wrapped into GemmWarpPolicy)
 *     [9]  clear_accum (PrimExpr)
 *     [10] stride_A (IntImm)
 *     [11] stride_B (IntImm)
 *     [12] offset_A (IntImm)
 *     [13] offset_B (IntImm)
 *     [14] kPack (IntImm, must be 1 or 2)
 *     [15] wg_wait (IntImm)
 *     [16] mbarrier pointer expression; when it is a Call node the matching
 *          buffer is looked up in `vmap`, otherwise no mbarrier buffer is set
 *     [17] first C coordinate (PrimExpr)
 *     [18] second C coordinate (PrimExpr)
 * @param vmap Buffer map used to resolve the A/B/C regions and the mbarrier
 *   buffer.
 *
 * @note Fails with an ICHECK if fewer than 19 arguments are supplied or if
 *       kPack is neither 1 nor 2. Entries with an unexpected node type fail
 *       when their `.as<...>().value()` is unwrapped.
 */
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
  // args[16..18] are accessed unconditionally below, so reject short argument
  // lists up front with a clear message instead of an opaque index error.
  ICHECK(args.size() >= 19)
      << "gemm_py expects at least 19 arguments, but got " << args.size();

  ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();

  // Operand regions and the buffers backing them.
  node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
  node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
  node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);
  node->a_ = node->aRegion_->buffer;
  node->b_ = node->bRegion_->buffer;
  node->c_ = node->cRegion_->buffer;

  // Transpose flags and problem shape.
  node->transA_ = args[3].as<Bool>().value();
  node->transB_ = args[4].as<Bool>().value();
  node->m_ = args[5].as<IntImm>().value()->value;
  node->n_ = args[6].as<IntImm>().value()->value;
  node->k_ = args[7].as<IntImm>().value()->value;

  // Warp partition policy and accumulator-clearing predicate.
  node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clearAccum_ = args[9].as<PrimExpr>().value();

  // Strides and offsets of the A and B operands.
  node->strideA_ = args[10].as<IntImm>().value()->value;
  node->strideB_ = args[11].as<IntImm>().value()->value;
  node->offsetA_ = args[12].as<IntImm>().value()->value;
  node->offsetB_ = args[13].as<IntImm>().value()->value;

  node->kPack_ = args[14].as<IntImm>().value()->value;
  ICHECK(node->kPack_ == 1 || node->kPack_ == 2) << "kPack must be 1 or 2";
  node->wgWait_ = args[15].as<IntImm>().value()->value;

  // The mbarrier buffer is only resolved when the argument is a call
  // expression; any other expression leaves the buffer unset.
  node->mbarPtr_ = args[16];
  if (node->mbarPtr_.as<CallNode>()) {
    node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
  } else {
    node->mbar_ = std::nullopt;
  }

  node->cCoords_ = Array<PrimExpr>(
      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});

  data_ = std::move(node);
}

/**
 * @brief Duplicate this node and hand it back as a TileOperator.
 *
 * A fresh GemmPyNode is copy-constructed from `*this` and wrapped in a GemmPy
 * reference, which is returned as a TileOperator.
 *
 * @return TileOperator A GemmPy operator owning the copied node.
 */
TileOperator GemmPyNode::Clone() const {
  auto copied_node = tvm::ffi::make_object<GemmPyNode>(*this);
  return GemmPy(copied_node);
}

110
/**
 * @brief Decide whether the TCGEN5MMA path may be used for this GEMM.
 *
 * All of the following must hold, checked in this order:
 * - the target satisfies TargetIsSm100;
 * - A has scope "shared.dyn", "shared" or "shared.tmem";
 * - B has scope "shared.dyn" or "shared";
 * - C has scope "shared.tmem";
 * - GetTCGEN5MMAMeta reports success for (M, N, K, A dtype, C dtype).
 *
 * @param target Compilation target.
 * @return true when every condition above is met, false otherwise.
 */
bool GemmPyNode::allowTcgen5Mma(Target target) const {
  if (!TargetIsSm100(target)) {
    return false;
  }

  const bool a_scope_ok = a_.scope() == "shared.dyn" ||
                          a_.scope() == "shared" ||
                          a_.scope() == "shared.tmem";
  if (!a_scope_ok) {
    return false;
  }

  const bool b_scope_ok = b_.scope() == "shared.dyn" || b_.scope() == "shared";
  if (!b_scope_ok) {
    return false;
  }

  return c_.scope() == "shared.tmem" &&
         GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first;
}

119
bool GemmPyNode::allowWgmma(int block_size, Target target) const {
120
121
  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

122
123
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
124
  return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
125
126
         TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
         checkWgmma();
127
128
}

129
130
131
/**
 * @brief Select the GEMM instruction family for the given block size and
 * target.
 *
 * Priority order:
 *   1. kTCGEN5MMA when allowTcgen5Mma(target) holds;
 *   2. kWGMMA when allowWgmma(block_size, target) holds;
 *   3. kMFMA when the target satisfies TargetIsCDNA;
 *   4. kMMA when the target satisfies any of TargetIsVolta, TargetIsAmpere,
 *      TargetIsTuring, TargetIsHopper or TargetIsSm100.
 * Any other target triggers an ICHECK failure.
 *
 * @param block_size Number of threads in the block.
 * @param target Compilation target.
 * @return The selected GemmInst.
 */
GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const {
  // Both predicates are evaluated up front, before any selection is made.
  const bool tcgen5mma_ok = allowTcgen5Mma(target);
  const bool wgmma_ok = allowWgmma(block_size, target);

  if (tcgen5mma_ok) {
    return GemmInst::kTCGEN5MMA;
  }
  if (wgmma_ok) {
    return GemmInst::kWGMMA;
  }
  if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  }

  const bool supports_mma = TargetIsVolta(target) || TargetIsAmpere(target) ||
                            TargetIsTuring(target) || TargetIsHopper(target) ||
                            TargetIsSm100(target);
  if (supports_mma) {
    return GemmInst::kMMA;
  }

  ICHECK(0) << "Unsupported target for gemm: " << target->str();
  // Not reached because of the ICHECK above; present to satisfy the compiler.
  return GemmInst::kMMA;
}

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K %
 * 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 * and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
179
180
bool GemmPyNode::checkWgmma() const {
  if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
181
182
183
    return false;
  }

184
185
186
187
188
189
190
191
192
193
194
  if (c_->dtype == DataType::Float(16)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
195
196
    else
      return false;
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
  } else if (c_->dtype == DataType::Float(32)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::BFloat(16) &&
             b_->dtype == DataType::BFloat(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::Float(32) &&
             b_->dtype == DataType::Float(32))
      return (!transA_) && transB_ && k_ % 8 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
214
215
    else
      return false;
216
217
218
219
220
221
222
223
224
  } else if (c_->dtype == DataType::Int(32)) {
    if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
    else
      return false;
  } else {
    return false;
  }
}

/**
 * @brief Parse the numeric GPU architecture from a Target's "arch" attribute.
 *
 * If the arch string starts with "sm_" followed by at least one decimal
 * digit, the leading run of digits is returned as an int (e.g. 80 for
 * "sm_80", 90 for "sm_90a"). In every other case 0 is returned: a missing
 * "sm_" prefix, no digit directly after the prefix, or a digit run longer
 * than 9 characters, which is rejected so the result always fits in an int.
 *
 * The digits are scanned by hand rather than with std::stoi so that malformed
 * strings such as "sm_" or "sm_xy" yield 0 instead of throwing.
 *
 * Preconditions: the target must have an "arch" attribute (checked via
 * ICHECK).
 *
 * @param target Target whose "arch" attribute is inspected.
 * @return int The parsed architecture number, or 0 if the arch string does not
 * match "sm_<num>".
 */
static int GetArchInt(Target target) {
  auto s = target->GetAttr<tvm::ffi::String>("arch");
  ICHECK(s.has_value());
  std::string arch = s.value();

  if (arch.rfind("sm_", 0) != 0) {
    return 0;
  }

  // Nine decimal digits always fit in a 32-bit int, so capping the length
  // avoids signed overflow.
  constexpr std::size_t kMaxDigits = 9;
  int arch_int = 0;
  std::size_t num_digits = 0;
  for (std::size_t i = 3; i < arch.size(); ++i) {
    const char ch = arch[i];
    if (ch < '0' || ch > '9') {
      break; // Stop at the first non-digit, e.g. the 'a' in "sm_90a".
    }
    if (++num_digits > kMaxDigits) {
      return 0;
    }
    arch_int = arch_int * 10 + (ch - '0');
  }
  return arch_int;
}

/**
 * @brief Lower this GEMM by delegating to the globally registered
 * "tl.gemm_py.lower" function.
 *
 * The thread extent must be a compile-time constant; it is used as the block
 * size to select the GEMM instruction and to compute the warp partition. The
 * registered function is then called with this operator, the layout map, the
 * target, the thread bounds and the thread variable, and must return a
 * PrimFunc that carries a "global_symbol" attribute.
 *
 * If the returned body is already a BlockRealize, its block is renamed to the
 * global symbol. Otherwise the body is wrapped in a new BlockRealize whose
 * block is named after the global symbol.
 *
 * @param T Lowering arguments (target, thread bounds, thread var, layout map).
 * @param analyzer Arithmetic analyzer (not used by this function).
 * @return Stmt The BlockRealize statement produced from the lowered PrimFunc.
 *
 * @note Fails with an ICHECK if the thread extent is not constant, or if the
 *       PrimFunc has no attrs / no "global_symbol"; fails with LOG(FATAL) if
 *       "tl.gemm_py.lower" is not registered.
 */
Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // as_const_int yields a null pointer for non-constant expressions; check it
  // before dereferencing instead of crashing.
  const auto *thread_extent = as_const_int(T.thread_bounds->extent);
  ICHECK(thread_extent != nullptr)
      << "gemm_py requires a constant thread extent to lower";
  auto block_size = *thread_extent;
  GemmInst gemm_inst = getGemmInst(block_size, T.target);

  // The returned partition is not used in this function.
  // NOTE(review): the call is kept because computeWarpPartition presumably
  // records the partition on the policy object — confirm before removing.
  auto [warp_m, warp_n] =
      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);

  if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
    auto prim_func =
        Downcast<PrimFunc>((*f)(tvm::ffi::GetRef<GemmPy>(this), T.layout_map,
                                T.target, T.thread_bounds, T.thread_var));
    ICHECK(prim_func->attrs.defined());
    auto global_symbol =
        prim_func->attrs.GetAttr<tvm::ffi::String>("global_symbol");
    ICHECK(global_symbol.has_value());

    if (prim_func->body.as<BlockRealizeNode>()) {
      // The body is already a BlockRealize: only rename its block.
      BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
      auto block = block_realize->block;
      {
        BlockNode *n = block.CopyOnWrite();
        n->name_hint = global_symbol.value();
      }
      return BlockRealize(block_realize->iter_values, block_realize->predicate,
                          block);
    }
    // Otherwise wrap the body in a new BlockRealize node.
    return BlockRealize(
        /*iter_values=*/Array<PrimExpr>(),
        /*predicate=*/const_true(),
        /*block=*/
        Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
              /*name_hint=*/global_symbol.value(), prim_func->body));
  } else {
    LOG(FATAL) << "No lower function found for gemm_py";
    return Stmt(); // Not reached because of LOG(FATAL); present to satisfy
                   // the compiler.
  }
}

/**
 * @brief Infer buffer layouts by delegating to the globally registered
 * "tl.gemm_py.infer_layout" function.
 *
 * Inference runs once per node: after a successful call `completed_` is set
 * and every later call returns an empty map. Each inferred layout that is a
 * Fragment is re-bound to the provided thread range via BindThreadRange; all
 * other layouts are passed through unchanged.
 *
 * The result is assembled in a separate map so that the map returned by the
 * registered function is never modified while it is being iterated.
 *
 * @param T Layout inference arguments (target and thread bounds are used).
 * @param level Inference level (not used by this function).
 * @return LayoutMap The inferred layouts, or an empty map if inference has
 *         already completed.
 *
 * @note Fails with LOG(FATAL) if "tl.gemm_py.infer_layout" is not registered.
 */
LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
                                  InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;

  if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
    LayoutMap inferred = Downcast<LayoutMap>(
        (*f)(tvm::ffi::GetRef<GemmPy>(this), T.target, T.thread_bounds));
    // Copy every entry into `results`, binding fragment layouts to the
    // provided thread range on the way.
    for (const auto &kv : inferred) {
      const Buffer &buf = kv.first;
      const Layout &layout = kv.second;
      if (auto frag = layout.as<Fragment>()) {
        results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds));
      } else {
        results.Set(buf, layout);
      }
    }
  } else {
    LOG(FATAL) << "No infer layout function found for gemm_py";
  }

  completed_ = true;
  return results;
}

// Register the GemmPy tile operator (presumably under the op name `gemm_py`,
// per the macro arguments) and mark calls to it as opaque
// (CallEffectKind::kOpaque).
// NOTE(review): the op is declared with 5 inputs, while the GemmPy constructor
// reads args[0] through args[18] — confirm that set_num_inputs(5) is intended.
TIR_REGISTER_TL_OP(GemmPy, gemm_py)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

328
// Static initializer: run GemmPyNode's reflection registration.
TVM_FFI_STATIC_INIT_BLOCK() {
  GemmPyNode::RegisterReflection();
}
329

330
TVM_FFI_STATIC_INIT_BLOCK() {
331
332
333
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.GemmPyGemmInst",
                        [](GemmPy gemm_py, int block_size, Target target) {
334
                          return gemm_py->getGemmInst(block_size, target);
335
                        });
336
}
337

338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def(
      "tl.get_tcgen5_mma_meta",
      [](int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
        auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype);
        Array<Integer> result;
        if (success) {
          result.push_back(Integer(meta.atom_m));
          result.push_back(Integer(meta.atom_n));
          result.push_back(Integer(meta.atom_k));
        }
        return result;
      });
  refl::GlobalDef().def(
      "tl.get_tcgen5_instr_desc",
      [](int atom_m, int atom_n, int atom_k, DataType ab_dtype,
         DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a,
         int scale_in_b) {
        uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype,
                                           c_dtype, a_is_k_major, b_is_k_major,
                                           scale_in_a, scale_in_b);
        return Integer(static_cast<int64_t>(desc));
      });
}

364
365
} // namespace tl
} // namespace tvm