layout.h 8.94 KB
Newer Older
1
2
3
4
5
6
7
8
/*!
 * \file Layout.h
 *
 */

#ifndef TVM_TL_LAYOUT_LAYOUT_H_
#define TVM_TL_LAYOUT_LAYOUT_H_

9
#include <exception>
#include <string>
#include <utility>

#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/object.h>

#include "../support/ffi_aliases.h"

17
18
19
20
21
namespace tvm {
namespace tl {

// NOTE(review): a using-directive in a header leaks every tir name into
// tvm::tl for all includers. It is kept because the declarations below use
// names such as Var and IterVar unqualified.
using namespace tir;

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// Common layout-related exceptions

/*!
 * \brief Exception signalling a layout conflict.
 *
 * Carries a caller-supplied message that is returned verbatim by what().
 */
class LayoutConflictException : public std::exception {
public:
  /*!
   * \param msg Description of the conflict. Taken by value and moved into
   *        place so that callers passing a temporary do not pay for a copy.
   */
  explicit LayoutConflictException(std::string msg) : msg_(std::move(msg)) {}

  /*! \brief Returns the message passed to the constructor. */
  const char *what() const noexcept override { return msg_.c_str(); }

private:
  std::string msg_;
};

/*!
 * \brief Exception signalling that a loop layout failed an injectivity
 *        requirement.
 *
 * Carries a caller-supplied message that is returned verbatim by what().
 */
class LoopLayoutInjectiveException : public std::exception {
public:
  /*!
   * \param msg Description of the failure. Taken by value and moved into
   *        place so that callers passing a temporary do not pay for a copy.
   */
  explicit LoopLayoutInjectiveException(std::string msg)
      : msg_(std::move(msg)) {}

  /*! \brief Returns the message passed to the constructor. */
  const char *what() const noexcept override { return msg_.c_str(); }

private:
  std::string msg_;
};

41
42
43
44
class Layout;
class Fragment;

class LayoutNode : public Object {
45
public:
46
47
48
49
50
51
52
53
54
55
56
57
58
  LayoutNode() = default;
  LayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);

  size_t InputDim() const { return input_size_.size(); }

  size_t OutputDim() const { return forward_index_.size(); }

  Array<PrimExpr> InputShape() const { return input_size_; }

  Array<PrimExpr> OutputShape() const;

  Array<PrimExpr> GetForwardIndex() const { return forward_index_; }

59
60
  virtual Array<PrimExpr> GetForwardVars() const;

61
  virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;
62
63

  virtual Layout Inverse() const;
64

65
66
67
68
69
70
71
72
  // Reshape the layout to a new logical shape. When aliasing buffers of
  // different dtypes, the element count may change while the underlying
  // byte-size stays equal. Use rescale_num/rescale_den to represent the
  // ratio between the old element size and the new element size in bytes.
  // Specifically, define factor = rescale_num / rescale_den where:
  //   new_num_elems = old_num_elems * factor
  // For example, f32->i8 (4B -> 1B) uses rescale_num=4, rescale_den=1.
  // i8->f32 (1B -> 4B) uses rescale_num=1, rescale_den=4.
73
  virtual Layout Reshape(const Array<PrimExpr> &shape,
74
75
76
                         arith::Analyzer *analyzer,
                         const PrimExpr rescale_num = Integer(1),
                         const PrimExpr rescale_den = Integer(1)) const;
77

78
  virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;
79

80
81
82
  virtual std::string DebugOutput() const;

  virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;
83

84
  static void RegisterReflection();
85
86
87
  TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object);
  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
      kTVMFFISEqHashKindTreeNode;
88

89
protected:
90
  virtual Map<Var, Range> getVarMap() const;
91
  void UpdateAnalyzer(arith::Analyzer *analyzer) const;
92
93
94
95
96
97
98
99
  Array<PrimExpr> forward_index_;
  Array<PrimExpr> input_size_;
};

/*!
 * \brief Layout reference class.
 */
class Layout : public ObjectRef {
100
public:
101
102
103
  TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);
  TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);

104
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode);
105
106
107
};

class FragmentNode : public LayoutNode {
108
public:
109
  FragmentNode() = default;
110
111
  FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
               PrimExpr forward_thread, PrimExpr replicate_size);
112
113
114

  PrimExpr GetForwardThread() const { return forward_thread_; }

115
116
  Array<PrimExpr> GetForwardVars() const final;

117
  Layout Inverse() const final;
118

119
120
121
  Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer,
                 const PrimExpr rescale_num = Integer(1),
                 const PrimExpr rescale_den = Integer(1)) const;
122

123
  std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;
124
125
126
127
128

  PrimExpr ThreadExtent() const;

  PrimExpr ReplicateExtent() const { return replicate_size_; };

129
130
  PrimExpr ForwardThread(const Array<PrimExpr> &vars,
                         const Optional<PrimExpr> &rep_var) const;
131

132
  Fragment Repeat(const Array<PrimExpr> &repeats, bool repeat_on_thread,
133
134
135
136
137
138
139
140
                  bool lower_dim_first = true) const;

  Fragment Replicate(int repeats) const;

  Fragment DeReplicate() const;

  Fragment CondenseReplicateVar() const;

141
142
  std::string DebugOutput() const final;

143
  Fragment BindThreadRange(Range thread_range) const;
144
145
146

  Range ThreadRange() const { return thread_range_; }

147
  bool IsEqual(const FragmentNode *other, bool skip_index = false) const;
148

149
150
  bool IsCompletedReplicated() const;

151
152
  arith::IterMapResult DetectInjective() const;

153
  static void RegisterReflection();
154

155
156
157
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
      kTVMFFISEqHashKindTreeNode;
158

159
protected:
160
  Map<Var, Range> getVarMap() const final;
161
  Range thread_range_;
162
163
164
165
166
167
168
169
  PrimExpr forward_thread_;
  PrimExpr replicate_size_;
};

/*!
 * \brief Fragment reference class.
 */
class Fragment : public Layout {
170
public:
171
172
173
174
  TVM_DLL Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
                   PrimExpr forward_thread, IterVar thread_replicate);

  TVM_DLL Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
175
176
                   PrimExpr forward_thread, PrimExpr replicate_size,
                   Optional<Var> replicate_var);
177

178
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
179
180
181
182
};

Var InputPlaceholder(size_t idx);
Var ReplicationPlaceholder();
183
IterVar make_itervar(std::string name, PrimExpr dom);
184
185
186

// Fragment factories for GEMM operands, named by operand (A, B, C) and by
// target variant (CDNA, Hopper, sparse). All are defined out of line.
Fragment makeGemmFragment8x8();
Fragment makeGemmFragment8x8Transposed();
Fragment makeGemmFragmentC(const int block_m, const int block_n,
                           const int warp_m, const int warp_n,
                           const int element_size);
Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
                                 const int warp_m, const int warp_n,
                                 const int element_size);
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size);
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
                                 const int warp_m, const int warp_n,
                                 const int element_size);
Fragment makeGemmFragmentA(const int block_m, const int block_n,
                           const int block_k, const int warp_m,
                           const int warp_n, const int element_size,
                           bool transposed = false);
Fragment makeGemmFragmentB(const int block_m, const int block_n,
                           const int block_k, const int warp_m,
                           const int warp_n, bool transposed = false);

Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                               const int block_k, const int warp_m,
                               const int warp_n, const int element_size,
                               const int k_pack, bool transposed = false);
211
212
213
214

// Default memory layouts for GEMM operands (defined out of line).
Layout makeGemmLayoutLinear(int stride, int continuous);
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
                        int element_size, bool k_inner = true);
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
                              int continuity, int element_size,
                              bool k_inner = true);
Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
                             int element_size, bool k_inner = true);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
                            int kPack);

// Volta-specific fragment and memory-layout factories.
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
                                const int warp_m, const int warp_n,
                                const int element_size);
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
                                const int block_k, const int warp_m,
                                const int warp_n);
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
                             bool k_inner = true);
233

234
235
236
237
238
// Tensor-op multiplicand layouts and bank-swizzle layouts (defined out of
// line). The Full/Half/Quarter prefixes name the swizzle variant.
Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous,
                                int elementsize, int crosswise);
Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous,
                                    int elementsize);

Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeQuarterBankSwizzleLayout(int stride, int continuous,
                                    int element_size);
243
244
245

namespace attr {
// Block attribute key under which the layouts of all buffers in the block
// are stored.
constexpr const char *kLayoutMap = "layout_map";
} // namespace attr
248

249
250
} // namespace tl
} // namespace tvm
251

252
#endif // TVM_TL_LAYOUT_LAYOUT_H_