gemm.h 4.54 KB
Newer Older
wangkx1's avatar
wangkx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
/*!
 * \file tl/op/gemm.h
 * \brief Declares the GEMM tile operator (Gemm / GemmNode) and its
 *        warp-partitioning policy (GemmWarpPolicy).
 *
 */

#ifndef TVM_TL_OP_GEMM_H_
#define TVM_TL_OP_GEMM_H_

#include "operator.h"

namespace tvm {

namespace tl {

using namespace tir;

/// Warp-partitioning policy selector, stored as an int in
/// GemmWarpPolicyNode::policy_type. The numeric values are part of that
/// stored representation, so they are spelled out explicitly.
enum class GemmWarpPolicyType : uint8_t {
  kSquare = 0,  ///< Default value of GemmWarpPolicyNode::policy_type.
  kFullRow = 1, ///< Queried via GemmWarpPolicyNode::isFullRow().
  kFullCol = 2, ///< Queried via GemmWarpPolicyNode::isFullCol().
  kFree = 3     ///< Set by GemmWarpPolicy(int m_warp, int n_warp).
};

// Target GEMM instruction
enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA, KMMAC };
class GemmWarpPolicyNode : public Object {
public:
  mutable int m_warp{0};
  mutable int n_warp{0};
  int policy_type;

  TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmWarpPolicy", GemmWarpPolicyNode, Object);

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<GemmWarpPolicyNode>()
        .def_ro("policy_type", &GemmWarpPolicyNode::policy_type)
        .def_ro("m_warp", &GemmWarpPolicyNode::m_warp)
        .def_ro("n_warp", &GemmWarpPolicyNode::n_warp);
  }

  std::pair<int, int> computeWarpPartition(int M, int N, int block_size,
                                           Target target,
                                           GemmInst gemm_inst) const;

  bool isSquare() const {
    return policy_type == int(GemmWarpPolicyType::kSquare);
  }
  bool isFullRow() const {
    return policy_type == int(GemmWarpPolicyType::kFullRow);
  }
  bool isFullCol() const {
    return policy_type == int(GemmWarpPolicyType::kFullCol);
  }
  bool isFree() const { return policy_type == int(GemmWarpPolicyType::kFree); }
};

/*!
 * \brief Reference to a GemmWarpPolicyNode.
 *
 * Construct it from a GemmWarpPolicyType, from its raw int value, or from an
 * explicit (m_warp, n_warp) pair. The pair form always uses the kFree policy.
 */
class GemmWarpPolicy : public ObjectRef {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmWarpPolicy, ObjectRef,
                                             GemmWarpPolicyNode);

  /*! \brief Policy from an enum value; forwards to the int overload. */
  explicit GemmWarpPolicy(GemmWarpPolicyType policy_type)
      : GemmWarpPolicy(static_cast<int>(policy_type)) {}

  /*! \brief Policy from its raw int encoding; m_warp/n_warp keep defaults. */
  explicit GemmWarpPolicy(int policy_type) {
    auto fresh = tvm::ffi::make_object<GemmWarpPolicyNode>();
    fresh->policy_type = policy_type;
    data_ = std::move(fresh);
  }

  /*! \brief kFree policy with a caller-chosen warp split. */
  explicit GemmWarpPolicy(int m_warp, int n_warp) {
    auto fresh = tvm::ffi::make_object<GemmWarpPolicyNode>();
    fresh->policy_type = static_cast<int>(GemmWarpPolicyType::kFree);
    fresh->m_warp = m_warp;
    fresh->n_warp = n_warp;
    data_ = std::move(fresh);
  }
};

/*!
 * \brief Node for the tile-level GEMM operator.
 *
 * Holds the operand buffers and regions plus the lowering parameters. These
 * are filled in by Gemm(Array<PrimExpr>), which is defined out of line.
 * Scalar fields carry in-class defaults so a default-constructed node never
 * exposes indeterminate values through the read-only reflection accessors
 * registered in RegisterReflection().
 */
class GemmNode : public TileOperatorNode {
public:
  bool checkWgmma() const;
  tir::Buffer a_, b_, c_;
  // BufferRegion for A, B and C.
  BufferRegion aRegion_, bRegion_, cRegion_;
  // Presumably whether A / B are transposed (per field names); set by Gemm().
  bool transA_{false};
  bool transB_{false};
  // NOTE(review): presumably the M/N/K problem extents; confirm in Gemm().
  int m_{0};
  int n_{0};
  int k_{0};
  int strideA_{0};
  int strideB_{0};
  int offsetA_{0};
  int offsetB_{0};
  PrimExpr clearAccum_ = const_false();
  // For the meaning of k_pack, see bitblas/tl/mfma_macro_generator.py::k_pack.
  // It only takes effect with CDNA MFMA instructions.
  int kPack_ = 1;
  int wgWait_ = 0;
  BufferRegion mbarRegion_;
  std::optional<tir::Buffer> mbar_; // Optional; only used for TCGEN5MMA.
  Array<PrimExpr> cCoords_;
  // Mutable so it can be updated from const members (e.g. Lower/InferLayout).
  mutable GemmWarpPolicy policy_;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode);

  /*! \brief Register the operator's fields as read-only reflected attributes. */
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<GemmNode>()
        .def_ro("a", &GemmNode::a_)
        .def_ro("b", &GemmNode::b_)
        .def_ro("c", &GemmNode::c_)
        .def_ro("aRegion", &GemmNode::aRegion_)
        .def_ro("bRegion", &GemmNode::bRegion_)
        .def_ro("cRegion", &GemmNode::cRegion_)
        .def_ro("transA", &GemmNode::transA_)
        .def_ro("transB", &GemmNode::transB_)
        .def_ro("m", &GemmNode::m_)
        .def_ro("n", &GemmNode::n_)
        .def_ro("k", &GemmNode::k_)
        .def_ro("strideA", &GemmNode::strideA_)
        .def_ro("strideB", &GemmNode::strideB_)
        .def_ro("offsetA", &GemmNode::offsetA_)
        .def_ro("offsetB", &GemmNode::offsetB_)
        .def_ro("clearAccum", &GemmNode::clearAccum_)
        .def_ro("kPack", &GemmNode::kPack_)
        .def_ro("wgWait", &GemmNode::wgWait_)
        .def_ro("policy", &GemmNode::policy_);
  }

  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;

  TileOperator Clone() const;

private:
  GemmInst getGemmInst(int block_size, Target target) const;
  bool allowTcgen5Mma(Target target) const;
  bool allowWgmma(int block_size, Target target) const;

  mutable bool completed_ = false;
};

/*!
 * \brief Reference wrapper around GemmNode.
 */
class Gemm : public TileOperator {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode);

  /*! \brief Accessor for this operator's registered Op (defined out of line). */
  static const Op &Get();

  /*! \brief Build the operator from call arguments \p args (defined out of line). */
  TVM_DLL Gemm(Array<PrimExpr> args);
};

} // namespace tl
} // namespace tvm

#endif //  TVM_TL_OP_GEMM_H_