// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file tl/op/parallel.h
 * \brief Infer layout from ops and parallel for
 */

#ifndef TVM_TL_OP_PARALLEL_H_
#define TVM_TL_OP_PARALLEL_H_

#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/layout.h"
#include "op.h"

namespace tvm {
namespace tl {

// NOTE: this using-directive lives in a header, so tir names become visible
// inside tvm::tl for every translation unit that includes this file.
using namespace tir;

// Forward declaration: ParallelLoopNestVisitor stores a ParallelOp pointer
// before the full ParallelOp definition appears further down.
class ParallelOp;

class ParallelLoopNestVisitor : public StmtExprVisitor {
26
27
28
29
30
private:
  ParallelLoopNestVisitor(ParallelOp *op) : p(op){};
  void VisitStmt_(const ForNode *op) final;
  void VisitStmt_(const BufferStoreNode *op) final;
  void VisitExpr_(const BufferLoadNode *op) final;
31

32
  ParallelOp *p;
33
34
35
36
37

  friend class ParallelOp;
};

class ParallelOp : public Operator {
38
public:
39
  ParallelOp(For root);
40
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
41
42
43
44
45
46

  Fragment GetLoopLayout() const { return loop_layout_; }
  For GetRoot() const { return root_; }
  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
  Optional<PrimExpr> GetPredicate(Var thread_var) const;

47
48
49
private:
  Fragment CompleteBufferFragment(const Buffer &buffer);
  bool IsCommonAccessIndice(const Buffer &buffer) const;
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
  void AddPredicate(PrimExpr expr) {
    predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
  }

  For root_;

  ParallelLoopNestVisitor V;

  Map<Buffer, Array<PrimExpr>> indice_map_;
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
  Array<IterVar> loop_vars_;

  Fragment loop_layout_;
  mutable arith::Analyzer analyzer_;
  Optional<PrimExpr> predicate_;

  friend class ParallelLoopNestVisitor;
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_PARALLEL_H_