/*!
 * \file tl/op/parallel.h
 * \brief Infer layout from ops and parallel for
 */

#ifndef TVM_TL_OP_PARALLEL_H_
#define TVM_TL_OP_PARALLEL_H_

#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/layout.h"
#include "op.h"

namespace tvm {
namespace tl {

using namespace tir;

class ParallelOp;

class ParallelLoopNestVisitor : public StmtExprVisitor {
23
24
25
26
27
private:
  ParallelLoopNestVisitor(ParallelOp *op) : p(op){};
  void VisitStmt_(const ForNode *op) final;
  void VisitStmt_(const BufferStoreNode *op) final;
  void VisitExpr_(const BufferLoadNode *op) final;
28

29
  ParallelOp *p;
30
31
32
33
34

  friend class ParallelOp;
};

class ParallelOp : public Operator {
35
public:
36
  ParallelOp(For root);
37
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
38
39
40
41
42
43

  Fragment GetLoopLayout() const { return loop_layout_; }
  For GetRoot() const { return root_; }
  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
  Optional<PrimExpr> GetPredicate(Var thread_var) const;

44
45
46
private:
  Fragment CompleteBufferFragment(const Buffer &buffer);
  bool IsCommonAccessIndice(const Buffer &buffer) const;
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
  void AddPredicate(PrimExpr expr) {
    predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
  }

  For root_;

  ParallelLoopNestVisitor V;

  Map<Buffer, Array<PrimExpr>> indice_map_;
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
  Array<IterVar> loop_vars_;

  Fragment loop_layout_;
  mutable arith::Analyzer analyzer_;
  Optional<PrimExpr> predicate_;

  friend class ParallelLoopNestVisitor;
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_PARALLEL_H_