/*!
 * \file tl/op/parallel.h
 * \brief Parallel-for tile operator (ParallelOp) and layout inference for
 *        parallel loops.
 */

#ifndef TVM_TL_OP_PARALLEL_H_
#define TVM_TL_OP_PARALLEL_H_

#include <exception>
#include <string>
#include <unordered_set>
#include <utility>

#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/layout.h"
#include "../transform/layout_reducer.h"
#include "./operator.h"

/**
 * Exception representing a layout conflict detected during layout inference.
 *
 * Stores an explanatory message retrievable via what().
 */

/**
 * Determine whether `small_frag` is guaranteed to be contained within
 * `large_frag` under the given index mappings and using the provided arithmetic
 * analyzer.
 *
 * @param small_frag The smaller fragment to test for containment.
 * @param large_frag The larger fragment that may contain `small_frag`.
 * @param small_frag_indices Index expressions mapping the small fragment into
 * buffer space.
 * @param large_frag_indices Index expressions mapping the large fragment into
 * buffer space.
 * @param analyzer_ Arithmetic analyzer used to simplify and prove index
 * relations.
 * @return true if containment can be proven; false otherwise.
 */

/**
 * Visitor that traverses a parallel loop nest to collect buffer access and
 * loop-structure information for a ParallelOpNode.
 *
 * The visitor records loop variables, buffer read/write accesses, and builds
 * predicates as it encounters BufferLoad/BufferStore and For nodes.
 */

/**
 * Represents a parallel for-loop operator in TileLang.
 *
 * Holds the root For loop, collects and exposes loop layout and access-index
 * information, and provides layout inference and lowering to TIR.
 *
 * Public methods expose the inferred loop layout, root loop, buffer index
 * mappings, and any per-thread predicate; Lower and InferLayout perform the
 * operator's lowering and layout inference respectively.
 */

/**
 * Create a ParallelOpNode from a root For loop.
 *
 * @param root The root For node representing the parallel loop nest.
 */

/**
 * Lower this parallel operator into a TIR statement suitable for codegen.
 *
 * @param T Lowering arguments and context.
 * @param analyzer Arithmetic analyzer for expression simplification during
 * lowering.
 * @return A TIR statement representing the lowered parallel loop.
 */

/**
 * Infer the layout mapping for this parallel operator at the specified level.
 *
 * @param T Arguments and context for layout inference.
 * @param level Inference granularity level.
 * @return A LayoutMap describing inferred buffer/layout relationships for the
 * operator.
 */

/**
 * Copy-construct a ParallelOpNode, preserving inferred layout and predicate.
 */

/**
 * Get the inferred loop layout fragment.
 *
 * @return The Fragment representing the loop's inferred layout (may be lazily
 * computed).
 */

/**
 * Get the root For loop of this operator.
 *
 * @return The root For AST node.
 */

/**
 * Get the mapping from each buffer to the array of index expressions used to
 * access it within the loop nest.
 *
 * @return A Map from Buffer to Array<PrimExpr> of access indices.
 */

/**
 * Retrieve the predicate expression associated with a given thread variable, if
 * any.
 *
 * @param thread_var The thread variable whose predicate is requested.
 * @return An Optional<PrimExpr> containing the predicate when present.
 */

/**
 * Create a copy of this operator and return it as a TileOperator handle.
 * (How deep the copy is depends on the out-of-line definition.)
 *
 * @return A TileOperator that references a copy of this node.
 */

/**
 * Private helper of ParallelOpNode: complete the fragment layout for a buffer.
 *
 * (Internal — not part of the public API.)
 */

/**
 * Helper to check whether a buffer's access indices are the common loop indices
 * (internal).
 *
 * (Private helper — not part of the public API.)
 */

/**
 * Add `expr` to the current predicate by logical AND; sets predicate if none
 * exists.
 *
 * (Private helper — not part of the public API.)
 */

/**
 * Thin handle type exposing ParallelOpNode as a TileOperator.
 *
 * Construct from a root For loop to create and own a ParallelOpNode instance.
 */

/**
 * Construct a ParallelOp handle from a root For loop.
 *
 * @param root The root For node representing the parallel loop nest.
 */
150
151
152
153
154
namespace tvm {
namespace tl {

using namespace tir;

/**
 * Exception representing a layout conflict detected during layout inference.
 *
 * The explanatory message given at construction is returned by what().
 */
class LayoutConflictException : public std::exception {
public:
  /**
   * @param msg Human-readable description of the conflict. Taken by value and
   *            moved into place so temporaries (e.g. `oss.str()`) are not
   *            copied a second time.
   */
  LayoutConflictException(std::string msg) : msg_(std::move(msg)) {}

  /// @return The message supplied at construction; valid while *this lives.
  const char *what() const noexcept override { return msg_.c_str(); }

private:
  std::string msg_;
};

/**
 * Try to prove that `small_frag` (addressed through `small_frag_indices`) is
 * contained in `large_frag` (addressed through `large_frag_indices`), using
 * `analyzer_` for simplification and proofs.
 *
 * @return true if containment can be proven; false otherwise.
 */
bool ProveFragmentContains(
    Fragment small_frag,
    Fragment large_frag,
    Array<PrimExpr> small_frag_indices,
    Array<PrimExpr> large_frag_indices,
    arith::Analyzer &analyzer_);

169
class ParallelOpNode;
170
171

class ParallelLoopNestVisitor : public StmtExprVisitor {
172
private:
173
174
175
176
  ParallelLoopNestVisitor(ParallelOpNode *op) : p(op){};
  void VisitStmt_(const ForNode *op) override;
  void VisitStmt_(const BufferStoreNode *op) override;
  void VisitExpr_(const BufferLoadNode *op) override;
177

178
  ParallelOpNode *p;
179

180
  friend class ParallelOpNode;
181
182
};

183
184
185
186
// ParallelOpNode represents a parallel for loop operator in TileLang.
// It is responsible for inferring layouts, holding loop structure, and managing
// predicates.
class ParallelOpNode : public TileOperatorNode {
187
public:
188
189
190
191
192
  // The inferred layout for the loop, mutable to allow lazy inference.
  mutable Fragment loop_layout_;
  // The predicate expression for the loop, if any, mutable for lazy
  // construction.
  mutable Optional<PrimExpr> predicate_;
193

194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
  // Type key for TVM object system.
  static constexpr const char *_type_key = "tl.ParallelOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode);

  // Construct from a root For loop.
  ParallelOpNode(For root);

  // Lower the operator to a TIR statement.
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;

  // Infer the layout for this parallel operator.
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;

  // Copy constructor for ParallelOpNode.
  ParallelOpNode(const ParallelOpNode &other) : ParallelOpNode(other.root_) {
210
211
212
213
    loop_layout_ = other.loop_layout_;
    predicate_ = other.predicate_;
  }

214
  // Get the inferred loop layout.
215
  Fragment GetLoopLayout() const { return loop_layout_; }
216
  // Get the root For loop.
217
  For GetRoot() const { return root_; }
218
  // Get the mapping from buffer to access indices.
219
  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
220
  // Get the predicate for a given thread variable.
221
222
  Optional<PrimExpr> GetPredicate(Var thread_var) const;

223
224
225
  // Clone this operator.
  TileOperator Clone() const;

226
private:
227
228
229
  // Complete the fragment layout for a given buffer.
  Fragment CompleteBufferFragment(const Buffer &buffer) const;
  // Check if the buffer is accessed with common indices (i.e., loop variables).
230
  bool IsCommonAccessIndice(const Buffer &buffer) const;
231
232
  // Add a predicate to the current predicate expression.
  void AddPredicate(PrimExpr expr) const {
233
234
    predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
  }
235
236
  // Allow ParallelLoopNestVisitor to access private members.
  friend class ParallelLoopNestVisitor;
237

238
  // The root For loop node.
239
  For root_;
240
  // Visitor for collecting loop nest information.
241
  ParallelLoopNestVisitor V;
242
  // Mapping from buffer to their access indices in the loop.
243
  Map<Buffer, Array<PrimExpr>> indice_map_;
244
  // Set of buffers that are written to in the loop.
245
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
246
  // The loop variables for the parallel loop nest.
247
  Array<IterVar> loop_vars_;
248
  // Analyzer for simplifying and analyzing expressions, mutable for lazy use.
249
  mutable arith::Analyzer analyzer_;
250
251
  // Mapping from buffer to reducer info.
  Map<Var, ReducerInfo> reducer_info_map_;
252
};
253

254
255
256
257
258
259
260
261
class ParallelOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode);

  ParallelOp(For root) {
    auto op = make_object<ParallelOpNode>(root);
    data_ = std::move(op);
  }
262
263
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_PARALLEL_H_