"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "f5a19acfefdacb4f73a791390df8df964f6c08e1"
parallel.h 1.71 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file tl/op/parallel.h
 * \brief Infer layouts from operators and parallel for-loops
 */

#ifndef TVM_TL_OP_PARALLEL_H_
#define TVM_TL_OP_PARALLEL_H_

#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/layout.h"
#include "op.h"

namespace tvm {
namespace tl {

using namespace tir;

class ParallelOp;

/*!
 * \brief Statement/expression visitor used by ParallelOp to walk its loop nest.
 *
 * The constructor is private and ParallelOp is the only friend, so only
 * ParallelOp can create one. The visitor keeps a raw, non-owning pointer back
 * to that ParallelOp; nothing here frees it. The overridden handlers for
 * `For`, `BufferStore` and `BufferLoad` nodes are defined out of line.
 */
class ParallelLoopNestVisitor : public StmtExprVisitor {
 private:
  /*!
   * \brief Bind the visitor to its ParallelOp.
   * \param op Non-owning pointer to the ParallelOp; presumably dereferenced by
   *        the out-of-line handlers, so it must stay valid while visiting.
   *
   * Marked explicit so a ParallelOp* never silently converts into a visitor.
   */
  explicit ParallelLoopNestVisitor(ParallelOp* op) : p(op) {}
  void VisitStmt_(const ForNode* op) final;
  void VisitStmt_(const BufferStoreNode* op) final;
  void VisitExpr_(const BufferLoadNode* op) final;

  // Non-owning back-pointer to the ParallelOp this visitor was created for.
  ParallelOp* p;

  friend class ParallelOp;
};

/*!
 * \brief Operator built from a `For` loop nest, taking part in layout inference.
 *
 * It is constructed from the root `For` of the nest. It keeps per-buffer index
 * expressions, a set of buffers, the loop variables, a loop layout (`Fragment`)
 * and an optional predicate. The loop body is presumably traversed by the
 * member visitor `V` during construction; confirm in the out-of-line
 * constructor.
 */
class ParallelOp : public Operator {
 public:
  /*!
   * \brief Create the operator for a loop nest.
   * \param root Outermost `For` of the nest (later returned by GetRoot()).
   */
  ParallelOp(For root);
  /*!
   * \brief Layout inference hook overriding Operator's (declared final).
   * \param T Layout-inference arguments (type declared in op.h).
   * \param level The InferLevel to run at (semantics defined in op.h).
   * \return A LayoutMap; presumably maps buffers to their inferred layouts.
   */
  LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final;

  // Accessors: each returns a copy of the corresponding member.
  Fragment GetLoopLayout() const { return loop_layout_; }
  For GetRoot() const { return root_; }
  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
  /*!
   * \brief Guard predicate for the loop body, if any.
   * \param thread_var Thread variable; presumably the predicate is expressed
   *        in terms of it (defined out of line, so confirm).
   * \return The predicate, or an undefined Optional when there is none.
   */
  Optional<PrimExpr> GetPredicate(Var thread_var) const;

 private:
  // Defined out of line. Presumably derives a Fragment layout for `buffer`
  // from loop_layout_ and the buffer's recorded indices. TODO confirm.
  Fragment CompleteBufferFragment(const Buffer& buffer);
  // Defined out of line. Presumably checks whether `buffer` is accessed with
  // the same (common) indices throughout the loop body. TODO confirm.
  bool IsCommonAccessIndice(const Buffer& buffer) const;
  // Conjoin `expr` with the predicate recorded so far (new term on the left).
  // If no predicate exists yet, `expr` becomes the predicate.
  void AddPredicate(PrimExpr expr) {
    predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
  }

  // Root of the loop nest (returned by GetRoot).
  For root_;

  // Visitor holding a raw back-pointer to this object.
  // Data members are initialized in declaration order, so root_ is ready
  // before V. Keep this order if the constructor relies on it.
  // NOTE(review): if a ParallelOp were ever copied or moved, the new object's
  // V would still point at the source object. Confirm this type is never
  // copied/moved (a non-copyable member such as arith::Analyzer may already
  // prevent it).
  ParallelLoopNestVisitor V;

  // Index expressions per buffer (exposed via GetIndiceMap). Presumably these
  // are the indices each buffer is accessed with inside the loop body.
  Map<Buffer, Array<PrimExpr>> indice_map_;
  // Buffers presumably written (stored to) in the loop body. Confirm in the
  // visitor's BufferStore handler.
  // NOTE(review): std::unordered_set is used here but <unordered_set> is not
  // included directly by this header; it presumably arrives transitively.
  // Consider including it explicitly.
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
  // Loop variables of the nest; presumably collected by the visitor's For handler.
  Array<IterVar> loop_vars_;

  // Layout of the loop iterations (returned by GetLoopLayout).
  Fragment loop_layout_;
  // mutable so const member functions (e.g. IsCommonAccessIndice,
  // GetPredicate) can presumably use the analyzer.
  mutable arith::Analyzer analyzer_;
  // Accumulated guard predicate; undefined until AddPredicate is first called.
  Optional<PrimExpr> predicate_;

  // Gives the visitor access to the private members above.
  friend class ParallelLoopNestVisitor;
};

}  // namespace tl
}  // namespace tvm

#endif  // TVM_TL_OP_PARALLEL_H_