/*!
 * \file tl/op/parallel.h
 * \brief Layout inference for ops and parallel for loops
 */

#ifndef TVM_TL_OP_PARALLEL_H_
#define TVM_TL_OP_PARALLEL_H_

#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/layout.h"
#include "op.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Exception that carries a human-readable message; by its name it
 *        signals a layout conflict.
 *
 * The message handed to the constructor is copied into the exception object,
 * so the pointer returned by what() remains valid for as long as this
 * exception object is alive.
 */
class LayoutConflictException : public std::exception {
  // Owned copy of the description supplied at construction time.
  std::string msg_;

public:
  /*! \brief Build the exception from the description \p msg (copied). */
  LayoutConflictException(const std::string &msg) : msg_{msg} {}

  /*! \brief Null-terminated view of the stored description. */
  const char *what() const noexcept override { return msg_.data(); }
};

/*!
 * \brief Declaration only; the definition lives outside this header.
 *
 * NOTE(review): judging by the name and parameters, this presumably tries to
 * prove that `small_frag` (accessed at `small_frag_indices`) is contained in
 * `large_frag` (accessed at `large_frag_indices`), using `analyzer_` for the
 * arithmetic reasoning -- confirm against the definition.
 *
 * \param small_frag          Fragment expected to be the contained one.
 * \param large_frag          Fragment expected to be the containing one.
 * \param small_frag_indices  Index expressions associated with `small_frag`.
 * \param large_frag_indices  Index expressions associated with `large_frag`.
 * \param analyzer_           Analyzer passed by non-const reference.
 * \return A boolean result; presumably true when containment was proven.
 */
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
                           Array<PrimExpr> small_frag_indices,
                           Array<PrimExpr> large_frag_indices,
                           arith::Analyzer &analyzer_);

class ParallelOp;

class ParallelLoopNestVisitor : public StmtExprVisitor {
37
38
39
40
41
private:
  ParallelLoopNestVisitor(ParallelOp *op) : p(op){};
  void VisitStmt_(const ForNode *op) final;
  void VisitStmt_(const BufferStoreNode *op) final;
  void VisitExpr_(const BufferLoadNode *op) final;
42

43
  ParallelOp *p;
44
45
46
47
48

  friend class ParallelOp;
};

class ParallelOp : public Operator {
49
public:
50
  ParallelOp(For root);
51
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
52

53
54
55
56
57
58
59
60
  ParallelOp(const ParallelOp &other) : ParallelOp(other.root_) {
    loop_layout_ = other.loop_layout_;
    predicate_ = other.predicate_;
  }
  std::unique_ptr<Operator> Clone() const final {
    return std::make_unique<ParallelOp>(*this);
  }

61
62
63
64
65
  Fragment GetLoopLayout() const { return loop_layout_; }
  For GetRoot() const { return root_; }
  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
  Optional<PrimExpr> GetPredicate(Var thread_var) const;

66
67
68
private:
  Fragment CompleteBufferFragment(const Buffer &buffer);
  bool IsCommonAccessIndice(const Buffer &buffer) const;
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
  void AddPredicate(PrimExpr expr) {
    predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
  }

  For root_;

  ParallelLoopNestVisitor V;

  Map<Buffer, Array<PrimExpr>> indice_map_;
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
  Array<IterVar> loop_vars_;

  Fragment loop_layout_;
  mutable arith::Analyzer analyzer_;
  Optional<PrimExpr> predicate_;

  friend class ParallelLoopNestVisitor;
};

} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_PARALLEL_H_