/*!
 * \file tl/op/gemm_py.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm_py.h"

#include <string>
#include <utility>

#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "tcgen5_meta.h"
#include "utils.h"

namespace tvm {
namespace tl {

using namespace tir;

// NormalizeToBufferRegion moved to src/op/utils.{h,cc}

// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}

/**
 * @brief Construct a Gemm operator from serialized TL arguments and a buffer
 * map.
 *
 * This constructor deserializes operator parameters from `args` and resolves
 * buffer references via `vmap`, populating an internal GemmPyNode with:
 * - device pointers for A, B, C and their corresponding Buffer objects,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmPyNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
 *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      (optional) kPack (Int), (optional) wg_wait (Int)]
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
53
GemmPy::GemmPy(Array<PrimExpr> args) {
54
  ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
55

56
57
58
  node->aRegion_ = NormalizeToBufferRegion(args[0]);
  node->bRegion_ = NormalizeToBufferRegion(args[1]);
  node->cRegion_ = NormalizeToBufferRegion(args[2]);
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

  node->a_ = node->aRegion_->buffer;
  node->b_ = node->bRegion_->buffer;
  node->c_ = node->cRegion_->buffer;
  node->transA_ = args[3].as<Bool>().value();
  node->transB_ = args[4].as<Bool>().value();
  node->m_ = args[5].as<IntImm>().value()->value;
  node->n_ = args[6].as<IntImm>().value()->value;
  node->k_ = args[7].as<IntImm>().value()->value;
  node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clearAccum_ = args[9].as<PrimExpr>().value();
  node->strideA_ = args[10].as<IntImm>().value()->value;
  node->strideB_ = args[11].as<IntImm>().value()->value;
  node->offsetA_ = args[12].as<IntImm>().value()->value;
  node->offsetB_ = args[13].as<IntImm>().value()->value;
74
  if (args.size() > 14) {
75
76
    node->kPack_ = args[14].as<IntImm>().value()->value;
    if (node->kPack_ != 1 && node->kPack_ != 2) {
77
78
79
80
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
81
    node->wgWait_ = args[15].as<IntImm>().value()->value;
82
  }
83
84
85
86
87
88
  if (args.size() > 16) {
    if (const auto *load = args[16].as<BufferLoadNode>()) {
      node->mbarRegion_ =
          NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
      node->mbar_ = node->mbarRegion_->buffer;
    }
89
  }
90
91
  node->cCoords_ = Array<PrimExpr>(
      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
92
93
94
95
96
97
98
99
100
101
102
103
  data_ = std::move(node);
}

/**
 * @brief Create a copy of this GemmPyNode as a TileOperator.
 *
 * Copy-constructs a new GemmPyNode from the current node state and returns it
 * wrapped in a GemmPy reference.
 *
 * @return TileOperator A GemmPy operator that owns the copied node.
 */
TileOperator GemmPyNode::Clone() const {
  auto copied = tvm::ffi::make_object<GemmPyNode>(*this);
  return GemmPy(copied);
}

108
bool GemmPyNode::allowTcgen5Mma(Target target) const {
109
  return TargetIsSm100(target) &&
110
111
112
113
114
         ((a_.scope() == "shared.dyn" || a_.scope() == "shared" ||
           a_.scope() == "shared.tmem") &&
          (b_.scope() == "shared.dyn" || b_.scope() == "shared") &&
          c_.scope() == "shared.tmem") &&
         GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first;
115
116
}

117
bool GemmPyNode::allowWgmma(int block_size, Target target) const {
118
119
  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

120
121
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
122
  return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
123
124
         TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
         checkWgmma();
125
126
}

127
128
129
GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const {
  bool allow_tcgen5mma = allowTcgen5Mma(target);
  bool allow_wgmma = allowWgmma(block_size, target);
130
131
132
  if (allow_tcgen5mma) {
    return GemmInst::kTCGEN5MMA;
  } else if (allow_wgmma) {
133
134
135
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
136
137
138
  } else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
             TargetIsTuring(target) || TargetIsHopper(target) ||
             TargetIsSm100(target)) {
139
140
141
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target->str();
142
143
    return GemmInst::kMMA; // This line will never be reached due to ICHECK, but
                           // satisfies compiler
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
  }
}

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K %
 * 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 * and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 *         transpose/shape constraints; false otherwise.
 */
177
178
bool GemmPyNode::checkWgmma() const {
  if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
179
180
181
    return false;
  }

182
183
184
185
186
187
188
189
190
191
192
  if (c_->dtype == DataType::Float(16)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
193
194
    else
      return false;
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
  } else if (c_->dtype == DataType::Float(32)) {
    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::BFloat(16) &&
             b_->dtype == DataType::BFloat(16))
      return k_ % 16 == 0;
    else if (a_->dtype == DataType::Float(32) &&
             b_->dtype == DataType::Float(32))
      return (!transA_) && transB_ && k_ % 8 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
      return (!transA_) && transB_ && k_ % 32 == 0;
212
213
    else
      return false;
214
215
216
217
218
219
220
221
222
  } else if (c_->dtype == DataType::Int(32)) {
    if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
      return (!transA_) && transB_ && k_ % 32 == 0;
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
    else
      return false;
  } else {
    return false;
  }
}

/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it has the form "sm_<num>",
 * returns <num> as an int. Trailing non-digit characters after the number are
 * ignored by std::stoi (e.g. "sm_90a" yields 90). If the string does not start
 * with "sm_" followed by at least one digit, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked via
 * ICHECK).
 *
 * @param target Target whose "arch" attribute is inspected.
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
static int GetArchInt(Target target) {
  auto s = target->GetAttr<tvm::ffi::String>("arch");
  ICHECK(s.has_value());
  const std::string arch = s.value();

  const bool has_sm_prefix = arch.rfind("sm_", 0) == 0;
  // std::stoi throws std::invalid_argument when no digit follows the prefix
  // (e.g. "sm_" or "sm_xyz"); treat those as "no match" and return 0 instead.
  const bool digit_follows =
      has_sm_prefix && arch.size() > 3 && arch[3] >= '0' && arch[3] <= '9';
  if (!digit_follows) {
    return 0;
  }
  return std::stoi(arch.substr(3));
}

/**
 * @brief Lower this GEMM by delegating to the registered global function
 * "tl.gemm_py.lower".
 *
 * The callback receives this operator, the layout map, target, thread bounds
 * and thread variable, and must return a PrimFunc that carries a
 * "global_symbol" attribute. If the PrimFunc body is already a BlockRealize,
 * its block is renamed to the global symbol; otherwise the body is wrapped in
 * a new BlockRealize whose block is named after the global symbol.
 *
 * @param T Lowering arguments (target, thread bounds, layout map, ...).
 * @param analyzer Arithmetic analyzer; not used by this implementation.
 * @return The lowered statement.
 *
 * Fails with ICHECK if the thread extent is not a constant integer, or if the
 * returned PrimFunc has no attributes / no "global_symbol"; fails with
 * LOG(FATAL) if "tl.gemm_py.lower" is not registered.
 */
Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // as_const_int yields nullptr for non-constant extents; check before
  // dereferencing.
  const auto *extent = as_const_int(T.thread_bounds->extent);
  ICHECK(extent != nullptr)
      << "gemm_py requires a constant thread extent, but got "
      << T.thread_bounds->extent;
  auto block_size = *extent;
  GemmInst gemm_inst = getGemmInst(block_size, T.target);

  // NOTE(review): the partition result is not used in this function. The call
  // is kept because computeWarpPartition may record state on the policy that
  // the lowering callback reads - confirm before removing.
  auto [warp_m, warp_n] =
      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
  (void)warp_m;
  (void)warp_n;

  const auto lower_fn = ffi::Function::GetGlobal("tl.gemm_py.lower");
  if (!lower_fn) {
    LOG(FATAL) << "No lower function found for gemm_py";
    return Stmt(); // This line will never be reached due to LOG(FATAL), but
                   // satisfies compiler
  }

  auto prim_func = Downcast<PrimFunc>(
      (*lower_fn)(tvm::ffi::GetRef<GemmPy>(this), T.layout_map, T.target,
                  T.thread_bounds, T.thread_var));
  ICHECK(prim_func->attrs.defined());
  auto global_symbol =
      prim_func->attrs.GetAttr<tvm::ffi::String>("global_symbol");
  ICHECK(global_symbol.has_value());

  if (prim_func->body.as<BlockRealizeNode>()) {
    // Body is already a BlockRealize: keep it, only rename its block.
    BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
    auto block = block_realize->block;
    {
      BlockNode *n = block.CopyOnWrite();
      n->name_hint = global_symbol.value();
    }
    return BlockRealize(block_realize->iter_values, block_realize->predicate,
                        block);
  }

  // wrap with block realize node
  return BlockRealize(
      /*iter_values=*/Array<PrimExpr>(),
      /*predicate=*/const_true(),
      /*block=*/
      Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
            /*name_hint=*/global_symbol.value(), prim_func->body));
}

/**
 * @brief Infer buffer layouts by delegating to the registered global function
 * "tl.gemm_py.infer_layout".
 *
 * The callback receives this operator, the target and the thread bounds and
 * returns a LayoutMap. Every Fragment layout in that map is bound to
 * `T.thread_bounds` via BindThreadRange; other layouts are passed through
 * unchanged. After the first successful call the operator is marked completed
 * and later calls return an empty map.
 *
 * @param T Layout inference arguments (target, thread bounds, ...).
 * @param level Inference level; not used by this implementation.
 * @return The inferred layouts, or an empty map if inference already ran.
 *
 * Fails with LOG(FATAL) if "tl.gemm_py.infer_layout" is not registered.
 */
LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
                                  InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;

  if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
    const LayoutMap inferred = Downcast<LayoutMap>(
        (*f)(tvm::ffi::GetRef<GemmPy>(this), T.target, T.thread_bounds));
    // Bind all fragment layouts with the provided thread range. The bound
    // entries go into a separate map so the map being iterated is never
    // modified during the loop.
    for (auto kv : inferred) {
      const Buffer &buf = kv.first;
      const Layout &layout = kv.second;
      if (auto frag = layout.as<Fragment>()) {
        results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds));
      } else {
        results.Set(buf, layout);
      }
    }
  } else {
    LOG(FATAL) << "No infer layout function found for gemm_py";
  }

  completed_ = true;
  return results;
}

321
TIR_REGISTER_TL_TILE_OP(GemmPy, gemm_py)
322
323
324
325
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

326
TVM_FFI_STATIC_INIT_BLOCK() { GemmPyNode::RegisterReflection(); }
327

328
TVM_FFI_STATIC_INIT_BLOCK() {
329
330
331
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.GemmPyGemmInst",
                        [](GemmPy gemm_py, int block_size, Target target) {
332
                          return gemm_py->getGemmInst(block_size, target);
333
                        });
334
}
335

336
337
338
339
340
341
342
343
344
345
346
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def(
      "tl.get_tcgen5_mma_meta",
      [](int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
        auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype);
        Array<Integer> result;
        if (success) {
          result.push_back(Integer(meta.atom_m));
          result.push_back(Integer(meta.atom_n));
          result.push_back(Integer(meta.atom_k));
347
348
          result.push_back(Integer(meta.enable_ws));
          result.push_back(Integer(meta.enable_2cta));
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
        }
        return result;
      });
  refl::GlobalDef().def(
      "tl.get_tcgen5_instr_desc",
      [](int atom_m, int atom_n, int atom_k, DataType ab_dtype,
         DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a,
         int scale_in_b) {
        uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype,
                                           c_dtype, a_is_k_major, b_is_k_major,
                                           scale_in_a, scale_in_b);
        return Integer(static_cast<int64_t>(desc));
      });
}

364
365
} // namespace tl
} // namespace tvm