/*!
 * \file tl/op/gemm.h
 * \brief Declare the GEMM tile operator and its warp partition policy.
 *
 */

#ifndef TVM_TL_OP_GEMM_H_
#define TVM_TL_OP_GEMM_H_

10
#include "operator.h"
11
12

namespace tvm {
13

14
15
16
17
namespace tl {

using namespace tir;

18
/*!
 * \brief Tag selecting the warp partition policy of a GEMM.
 *
 * The explicit values matter: GemmWarpPolicyNode stores the tag as a plain
 * int (policy_type) and GemmWarpPolicy can be constructed from a raw int.
 */
enum class GemmWarpPolicyType : uint8_t {
  kSquare = 0,
  kFullRow = 1,
  kFullCol = 2,
  // Tag assigned by GemmWarpPolicy(int m_warp, int n_warp), i.e. when the
  // caller provides the warp counts directly.
  kFree = 3,
};

25
26
/*!
 * \brief Instruction family a GEMM is lowered to.
 *
 * Enumerators keep their implicit numbering (0..3), spelled out here so the
 * mapping is visible at a glance.
 */
enum class GemmInst : uint8_t {
  kMMA = 0,
  kWGMMA = 1,
  kTCGEN5MMA = 2,
  kMFMA = 3,
};
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class GemmWarpPolicyNode : public Object {
public:
  mutable int m_warp{0};
  mutable int n_warp{0};
  int policy_type;

  static constexpr const char *_type_key = "tl.GemmWarpPolicy";
  TVM_DECLARE_FINAL_OBJECT_INFO(GemmWarpPolicyNode, Object);

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<GemmWarpPolicyNode>()
        .def_ro("policy_type", &GemmWarpPolicyNode::policy_type)
        .def_ro("m_warp", &GemmWarpPolicyNode::m_warp)
        .def_ro("n_warp", &GemmWarpPolicyNode::n_warp);
  }

  bool SEqualReduce(const GemmWarpPolicyNode *other,
                    SEqualReducer equal) const {
    return equal(policy_type, other->policy_type) &&
           equal(m_warp, other->m_warp) && equal(n_warp, other->n_warp);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(policy_type);
    hash_reduce(m_warp);
    hash_reduce(n_warp);
  }

  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;

  std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
60
61
                                           Target target,
                                           GemmInst gemm_inst) const;
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97

  bool isSquare() const {
    return policy_type == int(GemmWarpPolicyType::kSquare);
  }
  bool isFullRow() const {
    return policy_type == int(GemmWarpPolicyType::kFullRow);
  }
  bool isFullCol() const {
    return policy_type == int(GemmWarpPolicyType::kFullCol);
  }
  bool isFree() const { return policy_type == int(GemmWarpPolicyType::kFree); }
};

/*!
 * \brief Reference handle to a GemmWarpPolicyNode.
 *
 * All constructors are explicit and allocate a fresh node.
 */
class GemmWarpPolicy : public ObjectRef {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(GemmWarpPolicy, ObjectRef, GemmWarpPolicyNode);

  /*! \brief Build a policy from a typed tag; forwards to the int overload. */
  explicit GemmWarpPolicy(GemmWarpPolicyType policy_type)
      : GemmWarpPolicy(static_cast<int>(policy_type)) {}

  /*!
   * \brief Build a policy from a raw tag value.
   *
   * The value is stored as given; no range check is performed here.
   */
  explicit GemmWarpPolicy(int policy_type) {
    auto node = make_object<GemmWarpPolicyNode>();
    node->policy_type = policy_type;
    data_ = std::move(node);
  }

  /*!
   * \brief Build a kFree policy with caller-supplied warp counts.
   * \param m_warp Stored in GemmWarpPolicyNode::m_warp.
   * \param n_warp Stored in GemmWarpPolicyNode::n_warp.
   */
  explicit GemmWarpPolicy(int m_warp, int n_warp) {
    auto node = make_object<GemmWarpPolicyNode>();
    node->m_warp = m_warp;
    node->n_warp = n_warp;
    node->policy_type = static_cast<int>(GemmWarpPolicyType::kFree);
    data_ = std::move(node);
  }
};
99

100
101
class GemmNode : public TileOperatorNode {
public:
102
  bool CheckWGMMA() const;
103
  tir::Buffer A, B, C;
104
105
  // pointer to the A, B, C
  PrimExpr Aptr, Bptr, Cptr;
106
107
  bool trans_A, trans_B;
  int M, N, K;
108
109
  int stride_A, stride_B;
  int offset_A, offset_B;
110
  bool clear_accum = false;
111
112
  // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
  // only will be enabled under cdna mfma instructions
113
  int kPack = 1;
114
  int wg_wait = 0;
115
116
117
  PrimExpr mbarptr;
  std::optional<tir::Buffer> mbar; // mbar is optional, only used for TCGEN5MMA
  Array<PrimExpr> C_coords;
118
  mutable GemmWarpPolicy policy;
119
120
121
122

  static constexpr const char *_type_key = "tl.Gemm";
  TVM_DECLARE_FINAL_OBJECT_INFO(GemmNode, TileOperatorNode);

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<GemmNode>()
        .def_ro("A", &GemmNode::A)
        .def_ro("B", &GemmNode::B)
        .def_ro("C", &GemmNode::C)
        .def_ro("Aptr", &GemmNode::Aptr)
        .def_ro("Bptr", &GemmNode::Bptr)
        .def_ro("Cptr", &GemmNode::Cptr)
        .def_ro("trans_A", &GemmNode::trans_A)
        .def_ro("trans_B", &GemmNode::trans_B)
        .def_ro("M", &GemmNode::M)
        .def_ro("N", &GemmNode::N)
        .def_ro("K", &GemmNode::K)
        .def_ro("stride_A", &GemmNode::stride_A)
        .def_ro("stride_B", &GemmNode::stride_B)
        .def_ro("offset_A", &GemmNode::offset_A)
        .def_ro("offset_B", &GemmNode::offset_B)
        .def_ro("clear_accum", &GemmNode::clear_accum)
        .def_ro("kPack", &GemmNode::kPack)
        .def_ro("wg_wait", &GemmNode::wg_wait)
        .def_ro("policy", &GemmNode::policy);
  }

  bool SEqualReduce(const GemmNode *other, SEqualReducer equal) const {
    return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
           equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
           equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
           equal(trans_B, other->trans_B) && equal(M, other->M) &&
           equal(N, other->N) && equal(K, other->K) &&
           equal(stride_A, other->stride_A) &&
           equal(stride_B, other->stride_B) &&
155
           equal(offset_A, other->offset_A) &&
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
           equal(offset_B, other->offset_B) &&
           equal(clear_accum, other->clear_accum) &&
           equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
           equal(policy, other->policy);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(A);
    hash_reduce(B);
    hash_reduce(C);
    hash_reduce(Aptr);
    hash_reduce(Bptr);
    hash_reduce(Cptr);
    hash_reduce(trans_A);
    hash_reduce(trans_B);
    hash_reduce(M);
    hash_reduce(N);
    hash_reduce(K);
    hash_reduce(stride_A);
    hash_reduce(stride_B);
    hash_reduce(offset_A);
    hash_reduce(offset_B);
    hash_reduce(clear_accum);
    hash_reduce(kPack);
    hash_reduce(wg_wait);
    hash_reduce(policy);
  }
  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;

186
187
188
189
190
191
192
193
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;

  TileOperator Clone() const;

private:
  GemmInst GetGemmInst(int block_size, Target target) const;
194
195
  bool AllowTCGEN5MMA(Target target) const;
  bool AllowWGMMA(int block_size, Target target) const;
196
197
198
199
200
201
202
203
204

  mutable bool completed_ = false;
};

/*!
 * \brief Reference handle to a GemmNode.
 */
class Gemm : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(Gemm, TileOperator, GemmNode);
  /*!
   * \brief Construct the operator from an argument list and a buffer map.
   *
   * Defined out of line; the expected layout of \p args is established there.
   */
  TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap);
  // Defined out of line. NOTE(review): presumably returns the Op registered
  // for this operator -- confirm against the implementation.
  static const Op &Get();
};

207
208
} // namespace tl
} // namespace tvm
209

210
#endif //  TVM_TL_OP_GEMM_H_