"csrc/vscode:/vscode.git/clone" did not exist on "e38074b1e6ad0975acbfa15d858c4bd7cd005e99"
parallel.h 4.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
/*!
 * \file tl/op/parallel.h
 * \brief Infer layout from ops and parallel for loops
 */

#ifndef TVM_TL_OP_PARALLEL_H_
#define TVM_TL_OP_PARALLEL_H_

#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/layout.h"
#include "../transform/layout_reducer.h"
#include "./operator.h"

namespace tvm {
namespace tl {

using namespace tir;

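// Prove that `small_frag`, accessed at `small_frag_indices`, is contained
// within `large_frag`, accessed at `large_frag_indices`, using `analyzer_`
// for simplification. Returns true only when the containment can be proven.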
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
                           Array<PrimExpr> small_frag_indices,
                           Array<PrimExpr> large_frag_indices,
                           arith::Analyzer &analyzer_);

class ParallelOpNode;

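// Walks the loop nest owned by a ParallelOpNode, recording loop variables and
// the buffer indices touched by loads and stores so that layouts can be
// inferred later.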
class ParallelLoopNestVisitor : public StmtExprVisitor {
private:
  ParallelLoopNestVisitor(ParallelOpNode *op) : p(op) {};
  void VisitStmt_(const ForNode *op) override;
  void VisitStmt_(const BufferStoreNode *op) override;
  void VisitExpr_(const BufferLoadNode *op) override;

  ParallelOpNode *p;

  friend class ParallelOpNode;
};

// ParallelOpNode represents a parallel for loop operator in TileLang.
// It is responsible for inferring layouts, holding loop structure, and managing
// predicates.
class ParallelOpNode : public TileOperatorNode {
public:
  // The root For loop node.
  For root_;
  // The inferred layout for the loop, mutable to allow lazy inference.
  mutable Fragment loop_layout_;
  // The predicate expression for the loop, if any, mutable for lazy
  // construction.
  mutable Optional<PrimExpr> predicate_;

  // Type key for TVM object system.
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ParallelOp", ParallelOpNode,
                                    TileOperatorNode);

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<ParallelOpNode>()
        .def_ro("root", &ParallelOpNode::root_)
        .def_ro("loop_layout", &ParallelOpNode::loop_layout_)
        .def_ro("predicate", &ParallelOpNode::predicate_);
  }

  // Construct from a root For loop.
  ParallelOpNode(For root);

  // Lower the operator to a TIR statement.
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;

  // Infer the layout for this parallel operator.
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;

  // Copy constructor for ParallelOpNode.
  ParallelOpNode(const ParallelOpNode &other) : ParallelOpNode(other.root_) {
    loop_layout_ = other.loop_layout_;
    predicate_ = other.predicate_;
  }

  // Get the inferred loop layout.
  Fragment GetLoopLayout() const { return loop_layout_; }
  // Get the root For loop.
  For GetRoot() const { return root_; }
  // Get the mapping from buffers to their access indices.
  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
  // Get the predicate for a given thread variable.
  Optional<PrimExpr> GetPredicate(Var thread_var) const;

  // Clone this operator.
  TileOperator Clone() const override;

private:
  // Complete the fragment layout for a given buffer.
  Fragment CompleteBufferFragment(const Buffer &buffer) const;
  // Check if the buffer is accessed with common indices (i.e., loop variables).
  bool IsCommonAccessIndice(const Buffer &buffer) const;
  /**
   * Conjoin `expr` into the operator's predicate (logical AND). If no
   * predicate exists yet, `expr` becomes the predicate.
   *
   * @param expr Predicate expression to add.
   */
  void AddPredicate(const PrimExpr &expr) const {
    predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
  }
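
  // For illustration: AddPredicate(a) followed by AddPredicate(b) leaves
  // predicate_ as And(b, a); the first call simply seeds the predicate.
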
  // Expand let bindings to find fragment buffer accesses and add them to
  // indice_map_. This handles cases like: a = block_mask_f[i]; T.copy(A[a, 0],
  // ...)
  void ExpandLetBindings(const Map<Var, PrimExpr> &let_var_to_expr);

  // Allow ParallelLoopNestVisitor to access private members.
  friend class ParallelLoopNestVisitor;

  // Visitor for collecting loop nest information.
  ParallelLoopNestVisitor V;
  // Mapping from buffers to their access indices in the loop.
  Map<Buffer, Array<PrimExpr>> indice_map_;
  // Set of buffers that are written to in the loop.
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
  // The loop variables for the parallel loop nest.
  Array<IterVar> loop_vars_;
  // Mapping from inner loop variables to their IterVars.
  Map<Var, IterVar> inner_vars_;
  // Analyzer for simplifying and analyzing expressions, mutable for lazy use.
  mutable arith::Analyzer analyzer_;
  // Mapping from buffer variables to their reducer info.
  Map<Var, ReducerInfo> reducer_info_map_;
};

class ParallelOp : public TileOperator {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ParallelOp, TileOperator,
                                             ParallelOpNode);

  ParallelOp(const For &root) {
    auto op = tvm::ffi::make_object<ParallelOpNode>(root);
    data_ = std::move(op);
  }
};
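
// Illustrative usage sketch only; `loop`, `infer_args`, `level`, and
// `thread_var` are assumed to be supplied by the caller:
//
//   ParallelOp op(loop);                     // wrap a parallel For loop
//   op->InferLayout(infer_args, level);      // infer loop/buffer layouts
//   Fragment layout = op->GetLoopLayout();   // read back the inferred layout
//   Optional<PrimExpr> pred = op->GetPredicate(thread_var);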

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_PARALLEL_H_