parallel.h 4.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
/*!
 * \file tl/op/parallel.h
 * \brief Declares ParallelOp, the parallel-for tile operator, and its layout inference.
 */

#ifndef TVM_TL_OP_PARALLEL_H_
#define TVM_TL_OP_PARALLEL_H_

#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/layout.h"
13
#include "operator.h"
14
15
16
17
18
19

namespace tvm {
namespace tl {

using namespace tir;

20
21
22
23
24
25
26
27
28
29
30
31
32
33
/*!
 * \brief Exception type that carries a message string.
 *
 * The name indicates it is meant for layout conflicts; the sites that throw
 * it are outside this header.
 */
class LayoutConflictException : public std::exception {
private:
  // Owned copy of the message, so the pointer handed out by what() stays
  // valid for as long as this exception object is alive.
  std::string msg_;

public:
  // Build the exception from a message; the text is copied into msg_.
  LayoutConflictException(const std::string &msg) : msg_(msg) {}

  // Return the stored message as a NUL-terminated C string.
  const char *what() const noexcept override {
    const char *text = msg_.c_str();
    return text;
  }
};

/*!
 * \brief Declaration only; the definition is not part of this header.
 *
 * NOTE(review): judging by the name, this presumably returns true when the
 * analyzer can prove that `small_frag` (accessed at `small_frag_indices`) is
 * contained in `large_frag` (accessed at `large_frag_indices`) -- confirm
 * against the definition.
 *
 * \param small_frag          Fragment passed by value.
 * \param large_frag          Fragment passed by value.
 * \param small_frag_indices  Index expressions paired with `small_frag`.
 * \param large_frag_indices  Index expressions paired with `large_frag`.
 * \param analyzer_           Arithmetic analyzer, taken by non-const
 *                            reference, so the callee is allowed to modify it.
 * \return A boolean result (see the note above for its presumed meaning).
 */
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
                           Array<PrimExpr> small_frag_indices,
                           Array<PrimExpr> large_frag_indices,
                           arith::Analyzer &analyzer_);

34
class ParallelOpNode;
35
36

class ParallelLoopNestVisitor : public StmtExprVisitor {
37
private:
38
39
40
41
  ParallelLoopNestVisitor(ParallelOpNode *op) : p(op){};
  void VisitStmt_(const ForNode *op) override;
  void VisitStmt_(const BufferStoreNode *op) override;
  void VisitExpr_(const BufferLoadNode *op) override;
42

43
  ParallelOpNode *p;
44

45
  friend class ParallelOpNode;
46
47
};

48
49
50
51
// ParallelOpNode represents a parallel for loop operator in TileLang.
// It is responsible for inferring layouts, holding loop structure, and managing
// predicates.
class ParallelOpNode : public TileOperatorNode {
52
public:
53
54
55
56
57
  // The inferred layout for the loop, mutable to allow lazy inference.
  mutable Fragment loop_layout_;
  // The predicate expression for the loop, if any, mutable for lazy
  // construction.
  mutable Optional<PrimExpr> predicate_;
58

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
  // Type key for TVM object system.
  static constexpr const char *_type_key = "tl.ParallelOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode);

  // Construct from a root For loop.
  ParallelOpNode(For root);

  // Lower the operator to a TIR statement.
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;

  // Infer the layout for this parallel operator.
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;

  // Copy constructor for ParallelOpNode.
  ParallelOpNode(const ParallelOpNode &other) : ParallelOpNode(other.root_) {
75
76
77
78
    loop_layout_ = other.loop_layout_;
    predicate_ = other.predicate_;
  }

79
  // Get the inferred loop layout.
80
  Fragment GetLoopLayout() const { return loop_layout_; }
81
  // Get the root For loop.
82
  For GetRoot() const { return root_; }
83
  // Get the mapping from buffer to access indices.
84
  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
85
  // Get the predicate for a given thread variable.
86
87
  Optional<PrimExpr> GetPredicate(Var thread_var) const;

88
89
90
  // Clone this operator.
  TileOperator Clone() const;

91
private:
92
93
94
  // Complete the fragment layout for a given buffer.
  Fragment CompleteBufferFragment(const Buffer &buffer) const;
  // Check if the buffer is accessed with common indices (i.e., loop variables).
95
  bool IsCommonAccessIndice(const Buffer &buffer) const;
96
97
  // Add a predicate to the current predicate expression.
  void AddPredicate(PrimExpr expr) const {
98
99
    predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
  }
100
101
  // Allow ParallelLoopNestVisitor to access private members.
  friend class ParallelLoopNestVisitor;
102

103
  // The root For loop node.
104
  For root_;
105
  // Visitor for collecting loop nest information.
106
  ParallelLoopNestVisitor V;
107
  // Mapping from buffer to their access indices in the loop.
108
  Map<Buffer, Array<PrimExpr>> indice_map_;
109
  // Set of buffers that are written to in the loop.
110
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
111
  // The loop variables for the parallel loop nest.
112
  Array<IterVar> loop_vars_;
113
  // Analyzer for simplifying and analyzing expressions, mutable for lazy use.
114
  mutable arith::Analyzer analyzer_;
115
};
116

117
118
119
120
121
122
123
124
// Reference class paired with ParallelOpNode; derives from TileOperator.
class ParallelOp : public TileOperator {
public:
  // TVM macro that supplies the standard object-reference boilerplate for
  // this reference type / node type pair.
  TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode);

  // Allocate a ParallelOpNode from the root For loop and make this
  // reference point at the new node.
  ParallelOp(For root) {
    auto op = make_object<ParallelOpNode>(root);
    data_ = std::move(op);
  }
};

127
128
} // namespace tl
} // namespace tvm
129

130
#endif // TVM_TL_OP_PARALLEL_H_