layout.h 6.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
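/*!
 * \file layout.h
 * \brief Declarations of the Layout and Fragment objects and of the factory
 *        functions that build GEMM fragments and memory layouts.
 */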

#ifndef TVM_TL_LAYOUT_LAYOUT_H_
#define TVM_TL_LAYOUT_LAYOUT_H_

#include <tvm/arith/analyzer.h>

namespace tvm {
namespace tl {

using namespace tir;

class Layout;
class Fragment;

class LayoutNode : public Object {
20
public:
21
22
23
24
25
26
27
28
29
30
31
32
33
  LayoutNode() = default;
  LayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);

  size_t InputDim() const { return input_size_.size(); }

  size_t OutputDim() const { return forward_index_.size(); }

  Array<PrimExpr> InputShape() const { return input_size_; }

  Array<PrimExpr> OutputShape() const;

  Array<PrimExpr> GetForwardIndex() const { return forward_index_; }

34
35
  virtual Array<PrimExpr> GetForwardVars() const;

36
  virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;
37
38
39

  virtual Layout Inverse() const;

40
41
42
  virtual std::string DebugOutput() const;

  virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;
43
44

  static constexpr bool _type_has_method_sequal_reduce = true;
45
46
47
  static constexpr const char *_type_key = "tl.Layout";
  bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
  void VisitAttrs(tvm::AttrVisitor *v);
48
49
  TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);

50
protected:
51
  virtual Map<Var, Range> getVarMap() const;
52
  void UpdateAnalyzer(arith::Analyzer *analyzer) const;
53
54
55
56
57
58
59
60
  Array<PrimExpr> forward_index_;
  Array<PrimExpr> input_size_;
};

/*!
 * \brief Managed reference to a LayoutNode.
 *
 * Both constructors are defined out of line.
 */
class Layout : public ObjectRef {
public:
  /*!
   * \brief Construct a layout from iteration variables.
   * \param forward_var Iteration variables of the input. Presumably their
   *        domains provide the input shape -- confirm against the
   *        implementation.
   * \param forward_index Forward index expressions.
   */
  TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);

  /*!
   * \brief Construct a layout from an explicit input shape.
   * \param input_size Per-dimension sizes of the input.
   * \param forward_index Forward index expressions.
   */
  TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);

  TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode);
};

class FragmentNode : public LayoutNode {
69
public:
70
  FragmentNode() = default;
71
72
  FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
               PrimExpr forward_thread, PrimExpr replicate_size);
73
74
75

  PrimExpr GetForwardThread() const { return forward_thread_; }

76
77
  Array<PrimExpr> GetForwardVars() const final;

78
79
80
81
82
83
  Layout Inverse() const final;

  PrimExpr ThreadExtent() const;

  PrimExpr ReplicateExtent() const { return replicate_size_; };

84
85
  PrimExpr ForwardThread(const Array<PrimExpr> &vars,
                         const Optional<PrimExpr> &rep_var) const;
86

87
  Fragment Repeat(const Array<PrimExpr> &repeats, bool repeat_on_thread,
88
89
90
91
92
93
94
95
                  bool lower_dim_first = true) const;

  Fragment Replicate(int repeats) const;

  Fragment DeReplicate() const;

  Fragment CondenseReplicateVar() const;

96
97
  std::string DebugOutput() const final;

98
99
100
101
  Fragment SetThreadRange(Range thread_range);

  Range ThreadRange() const { return thread_range_; }

102
  bool IsEqual(const FragmentNode *other, bool skip_index = false) const;
103

104
  void VisitAttrs(tvm::AttrVisitor *v);
105

106
107
  bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
  static constexpr const char *_type_key = "tl.Fragment";
108
109
  TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);

110
protected:
111
  Map<Var, Range> getVarMap() const final;
112
  Range thread_range_;
113
114
115
116
117
118
119
120
  PrimExpr forward_thread_;
  PrimExpr replicate_size_;
};

/*!
 * \brief Fragment reference class.
 */
class Fragment : public Layout {
121
public:
122
123
124
125
  TVM_DLL Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
                   PrimExpr forward_thread, IterVar thread_replicate);

  TVM_DLL Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
126
127
                   PrimExpr forward_thread, PrimExpr replicate_size,
                   Optional<Var> replicate_var);
128
129

  TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
130
131
132
133
134
135

  Fragment SetThreadRange(Range thread_range) {
    auto node = make_object<FragmentNode>(*this->get());
    node->SetThreadRange(thread_range);
    return Fragment(node);
  }
136
137
138
139
140
141
142
};

/*!
 * \brief Return the placeholder variable associated with index `idx`.
 * \note Defined out of line. Presumably stands for the idx-th input dimension
 *       in forward expressions -- confirm against the implementation.
 */
Var InputPlaceholder(size_t idx);
/*!
 * \brief Return the placeholder variable used for replication.
 * \note Defined out of line; exact usage should be confirmed there.
 */
Var ReplicationPlaceholder();

// Factory functions that build Fragment objects for GEMM operands. All of
// them are defined out of line; this header only fixes their signatures.
// Parameter meanings below are inferred from the names (block_* / warp_* look
// like block and warp tile sizes, element_size like an element width) and
// should be confirmed against the implementations.

/*! \brief Build the 8x8 base fragment. */
Fragment makeGemmFragment8x8();
/*! \brief Build the transposed variant of the 8x8 base fragment. */
Fragment makeGemmFragment8x8Transposed();

/*! \brief Build the fragment for the C operand. */
Fragment makeGemmFragmentC(const int block_m, const int block_n,
                           const int warp_m, const int warp_n,
                           const int element_size);
/*! \brief Build the C-operand fragment, CDNA variant. */
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size);
/*! \brief Build the C-operand fragment, Hopper variant. */
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
                                 const int warp_m, const int warp_n,
                                 const int element_size);
/*! \brief Build the fragment for the A operand. */
Fragment makeGemmFragmentA(const int block_m, const int block_n,
                           const int block_k, const int warp_m,
                           const int warp_n, const int element_size);
/*! \brief Build the fragment for the B operand (no element_size parameter). */
Fragment makeGemmFragmentB(const int block_m, const int block_n,
                           const int block_k, const int warp_m,
                           const int warp_n);

/*!
 * \brief Build the A-operand fragment, CDNA variant.
 * \param transposed Defaults to false; presumably selects a transposed operand
 *        layout -- confirm against the implementation.
 */
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                               const int block_k, const int warp_m,
                               const int warp_n, const int element_size,
                               bool transposed = false);
163
164
165
166

// Default memory layouts.
// All factories below are defined out of line. `stride` / `continuous` look
// like the extents of the strided and contiguous dimensions and
// `element_size` like an element width -- confirm against the
// implementations.

/*! \brief Build a linear layout. */
Layout makeGemmLayoutLinear(int stride, int continuous);
/*! \brief Build a padded layout for the A/B operands. */
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
/*!
 * \brief Build the layout for the A/B operands.
 * \note The meaning of `continuity` and `kfactor` is defined by the
 *       implementation.
 */
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
                        int element_size, int kfactor);
/*! \brief Build the A/B operand layout, CDNA variant. */
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
                            int kfactor);

/*! \brief Build the C-operand fragment, Volta variant. */
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
                                const int warp_m, const int warp_n,
                                const int element_size);
/*! \brief Build the A-operand fragment, Volta variant. */
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
                                const int block_k, const int warp_m,
                                const int warp_n);
/*!
 * \brief Build the A/B operand layout, Volta variant.
 * \param is_a Presumably true for the A operand and false for B -- confirm
 *        against the implementation.
 */
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
                             int kfactor);
180
181
182
183
184
185

/*!
 * \brief Build the "full bank" swizzle layout (defined out of line).
 * \note Parameter meanings are inferred from their names; confirm against the
 *       implementation.
 */
Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size);
/*!
 * \brief Build the "half bank" swizzle layout (defined out of line).
 * \note Parameter meanings are inferred from their names; confirm against the
 *       implementation.
 */
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size);

namespace attr {
/*!
 * \brief Block attribute key under which the layouts of all buffers in the
 *        block are stored.
 */
constexpr const char *kLayoutMap = "layout_map";
} // namespace attr
188

189
190
} // namespace tl
} // namespace tvm
191

192
#endif // TVM_TL_LAYOUT_LAYOUT_H_