"python/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "6fc175968c3a9fc0521948aa3636887cd6d84107"
parallel.h 7.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
/*!
 * \file tl/op/parallel.h
 * \brief Infer layout from ops and parallel for
 */

#ifndef TVM_TL_OP_PARALLEL_H_
#define TVM_TL_OP_PARALLEL_H_

#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/layout.h"
#include "../transform/layout_reducer.h"
#include "./operator.h"

/**
 * Exception indicating a layout conflict during layout inference or validation.
 * The stored message is returned by what().
 */

/**
 * Verify that `small_frag` is contained within `large_frag` under the provided
 * index mappings and using symbolic reasoning via `analyzer_`.
 *
 * @param small_frag Fragment describing the smaller layout fragment.
 * @param large_frag Fragment describing the larger layout fragment.
 * @param small_frag_indices Index expressions that map accesses into
 * `small_frag`.
 * @param large_frag_indices Index expressions that map accesses into
 * `large_frag`.
 * @param analyzer_ Analyzer used for symbolic simplification and proving
 * relations.
 * @return true if `small_frag` can be proven to be contained in `large_frag`
 * given the index mappings and analyzer; false otherwise.
 */

/**
 * Visitor that traverses a parallel loop nest to collect loop structure,
 * buffer access patterns, and to populate the associated ParallelOpNode.
 */

/**
 * Construct a ParallelOpNode from a root For loop.
 *
 * @param root The TIR For node that is the root of the parallel loop nest.
 */

/**
 * Lower this ParallelOpNode to a TIR statement.
 *
 * Performs lowering of the operator (including any necessary predicates,
 * reductions, and loop transformations) to produce an equivalent tir::Stmt.
 *
 * @param T Lowering options and context.
 * @param analyzer Optional analyzer for symbolic simplification during
 * lowering.
 * @return A tir::Stmt representing the lowered operator.
 */

/**
 * Infer layouts for buffers used by this parallel operator.
 *
 * This performs layout inference at the requested level and returns a mapping
 * from buffers to their inferred layout fragments.
 *
 * @param T Layout inference arguments and context.
 * @param level Granularity level for inference.
 * @return LayoutMap mapping buffers to inferred fragments.
 */

/**
 * Return an optional predicate expression associated with the given thread
 * variable.
 *
 * If the loop nest imposes a condition on `thread_var` (e.g., bounds checks or
 * tiling edge predicates), this returns the combined predicate; otherwise
 * returns an empty Optional.
 *
 * @param thread_var The thread variable for which to retrieve the predicate.
 * @return Optional containing the predicate expression if present.
 */

/**
 * Create and return a clone of this operator as a TileOperator (deep copy of
 * operator state necessary for further transformations).
 *
 * @return A TileOperator referencing a cloned ParallelOpNode.
 */

/**
 * Complete the layout fragment for `buffer` by filling in any missing
 * dimension or stride information derived from access patterns in the loop
 * nest.
 *
 * @param buffer The buffer whose fragment should be completed.
 * @return A Fragment representing the completed layout for `buffer`.
 */

/**
 * Determine whether `buffer` is accessed using only the loop-common indices
 * (i.e., indices that correspond to the loop variables of this parallel nest).
 *
 * @param buffer The buffer to inspect.
 * @return true if accesses use common loop indices; false otherwise.
 */

/**
 * Conjoin `expr` into the operator's predicate (logical AND). If no predicate
 * exists yet, `expr` becomes the predicate.
 *
 * @param expr Predicate expression to add.
 */

namespace tvm {
namespace tl {

using namespace tir;

class LayoutConflictException : public std::exception {
public:
  const char *what() const noexcept override { return msg_.c_str(); }
  LayoutConflictException(const std::string &msg) : msg_(msg) {}

private:
  std::string msg_;
};
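// Illustrative sketch (hypothetical driver code, not part of this header):
// a layout-inference driver can surface conflicts as ordinary TVM errors.
//
//   try {
//     auto layouts = op->InferLayout(infer_args, level);
//   } catch (const LayoutConflictException &e) {
//     LOG(FATAL) << "Layout conflict in parallel loop: " << e.what();
//   }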

bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
                           Array<PrimExpr> small_frag_indices,
                           Array<PrimExpr> large_frag_indices,
                           arith::Analyzer &analyzer_);
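// Illustrative call sketch (hypothetical fragments and indices): checks
// whether every element reached through `small` at indices {i} is also
// covered by `large` at indices {i, j}.
//
//   arith::Analyzer analyzer;
//   bool contained =
//       ProveFragmentContains(small, large, {i}, {i, j}, analyzer);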

class ParallelOpNode;

class ParallelLoopNestVisitor : public StmtExprVisitor {
private:
  ParallelLoopNestVisitor(ParallelOpNode *op) : p(op) {}
  void VisitStmt_(const ForNode *op) override;
  void VisitStmt_(const BufferStoreNode *op) override;
  void VisitExpr_(const BufferLoadNode *op) override;

  ParallelOpNode *p;

  friend class ParallelOpNode;
};
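// Illustrative sketch of what gets collected (hypothetical loop nest): for a
// parallel nest conceptually like
//
//   for i, j in T.Parallel(M, N):
//     C[i, j] = A[i, j] + B[i, j]
//
// the visitor records loop_vars_ = {i, j}, indice_map_[A] = indice_map_[B] =
// indice_map_[C] = {i, j}, and adds C to buffer_is_write_ on the owning
// ParallelOpNode.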

// ParallelOpNode represents a parallel for loop operator in TileLang.
// It is responsible for inferring layouts, holding loop structure, and managing
// predicates.
class ParallelOpNode : public TileOperatorNode {
public:
  // The inferred layout for the loop, mutable to allow lazy inference.
  mutable Fragment loop_layout_;
  // The predicate expression for the loop, if any, mutable for lazy
  // construction.
  mutable Optional<PrimExpr> predicate_;

  // Type key for TVM object system.
  static constexpr const char *_type_key = "tl.ParallelOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode);

  // Construct from a root For loop.
  ParallelOpNode(For root);

  // Lower the operator to a TIR statement.
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;

  // Infer the layout for this parallel operator.
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;

  // Copy constructor for ParallelOpNode.
  ParallelOpNode(const ParallelOpNode &other) : ParallelOpNode(other.root_) {
    loop_layout_ = other.loop_layout_;
    predicate_ = other.predicate_;
  }

  // Get the inferred loop layout.
  Fragment GetLoopLayout() const { return loop_layout_; }
  // Get the root For loop.
  For GetRoot() const { return root_; }
  // Get the mapping from buffer to access indices.
  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
  // Get the predicate for a given thread variable.
  Optional<PrimExpr> GetPredicate(Var thread_var) const;

  // Clone this operator.
  TileOperator Clone() const override;

private:
  // Complete the fragment layout for a given buffer.
  Fragment CompleteBufferFragment(const Buffer &buffer) const;
  // Check if the buffer is accessed with common indices (i.e., loop variables).
  bool IsCommonAccessIndice(const Buffer &buffer) const;
  // Add a predicate to the current predicate expression.
  void AddPredicate(const PrimExpr &expr) const {
    predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
  }
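  // For example, AddPredicate(i < M) followed by AddPredicate(j < N) leaves
  // predicate_ == And(j < N, i < M).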
  // Allow ParallelLoopNestVisitor to access private members.
  friend class ParallelLoopNestVisitor;

  // The root For loop node.
  For root_;
  // Visitor for collecting loop nest information.
  ParallelLoopNestVisitor V;
  // Mapping from buffer to their access indices in the loop.
  Map<Buffer, Array<PrimExpr>> indice_map_;
  // Set of buffers that are written to in the loop.
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
  // The loop variables for the parallel loop nest.
  Array<IterVar> loop_vars_;
  // Analyzer for simplifying and analyzing expressions, mutable for lazy use.
  mutable arith::Analyzer analyzer_;
  // Mapping from buffer variable to its reducer info.
  Map<Var, ReducerInfo> reducer_info_map_;
};

class ParallelOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode);

  ParallelOp(const For &root) {
    auto op = make_object<ParallelOpNode>(root);
    data_ = std::move(op);
  }
};
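// Illustrative usage sketch (hypothetical surrounding code; `for_node`,
// `infer_args`, `lower_args`, `level`, `thread_var`, and `analyzer` are
// assumed to be provided by the caller):
//
//   ParallelOp op(for_node);
//   LayoutMap layouts = op->InferLayout(infer_args, level);
//   Fragment loop_layout = op->GetLoopLayout();
//   Optional<PrimExpr> pred = op->GetPredicate(thread_var);
//   Stmt lowered = op->Lower(lower_args, &analyzer);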

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_PARALLEL_H_