"format/time.go" did not exist on "da7ddbb4dc856eeb4932cec95f355aa7b7fa6f49"
merge_if_stmt.cc 3.04 KB
Newer Older
1
2
3
4
5
/*!
 * \file merge_if_stmt.cc
 * \brief Merge consecutive else-less IfThenElse statements that have
 *        structurally equal conditions within a SeqStmt.
 */

6
#include <tvm/ffi/reflection/registry.h>
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Mutator that fuses adjacent else-less IfThenElse statements whose
 *        conditions are structurally equal.
 *
 * Inside every SeqStmt, a run such as
 *   if (c) { A }  if (c) { B }
 * is rewritten to
 *   if (c) { A; B }
 * When a run holds more than one body, the fused SeqStmt is visited again so
 * runs inside it can be merged as well. An IfThenElse that has an else branch
 * is never merged; it (like any other statement) ends the current run.
 */
class MergeIfStmtRewriter : public StmtExprMutator {
public:
  /*!
   * \brief Apply the rewrite to the body of \p f.
   * \param f Function to rewrite. Its body is replaced through CopyOnWrite and
   *          the same handle is returned.
   * \return The rewritten function.
   */
  static PrimFunc Substitute(PrimFunc &f) {
    auto rewriter = MergeIfStmtRewriter();
    f.CopyOnWrite()->body = rewriter(f->body);
    return f;
  }

private:
  MergeIfStmtRewriter() = default;

  Stmt VisitStmt_(const SeqStmtNode *op) final {
    Array<Stmt> new_seq;

    // Shared condition and collected then-bodies of the run currently being
    // built. No run is open when `current_if_bodies` is empty.
    PrimExpr current_condition;
    Array<Stmt> current_if_bodies;

    // Emit the open run (if any) as one IfThenElse and reset the run state.
    auto flush_pending = [&]() {
      if (current_if_bodies.empty()) {
        return;
      }
      // A lone body is reused directly. Several bodies are wrapped in a
      // SeqStmt and visited again so nested runs inside it merge too.
      Stmt merged_body = current_if_bodies.size() == 1
                             ? current_if_bodies[0]
                             : this->VisitStmt(SeqStmt(current_if_bodies));
      new_seq.push_back(IfThenElse(current_condition, merged_body, Stmt()));
      current_condition = PrimExpr();
      current_if_bodies.clear();
    };

    for (const Stmt &stmt : op->seq) {
      Stmt new_stmt = this->VisitStmt(stmt);
      const auto *if_node = new_stmt.as<IfThenElseNode>();
      if (if_node != nullptr && !if_node->else_case.defined()) {
        // NOTE(review): only structural equality of the conditions is checked.
        // If an earlier then-body can change a value the condition reads
        // (e.g. a buffer store), fusing the two ifs changes semantics.
        // Confirm that producers of this IR rule that out.
        if (current_condition.defined() &&
            StructuralEqual()(current_condition, if_node->condition)) {
          // Same condition as the open run: extend it.
          current_if_bodies.push_back(if_node->then_case);
        } else {
          // Different condition: close the previous run and start a new one.
          flush_pending();
          current_condition = if_node->condition;
          current_if_bodies.push_back(if_node->then_case);
        }
        continue;
      }

      // Any other statement (including an if with an else branch) ends the
      // run and is kept in its original position.
      flush_pending();
      new_seq.push_back(new_stmt);
    }

    flush_pending();

    return new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq);
  }
};

using namespace tir::transform;
/*!
 * \brief Build the "tl.MergeIfStmt" PrimFunc pass.
 *
 * Each PrimFunc is run through MergeIfStmtRewriter::Substitute. The module and
 * pass context are not consulted.
 * \return The pass, created at optimization level 0 with no required passes.
 */
tvm::transform::Pass MergeIfStmt() {
  constexpr int kOptLevel = 0;
  auto merge_fn = [](PrimFunc func, IRModule /*mod*/,
                     PassContext /*pass_ctx*/) {
    return MergeIfStmtRewriter::Substitute(func);
  };
  return CreatePrimFuncPass(merge_fn, kOptLevel, "tl.MergeIfStmt", {});
}

// Register MergeIfStmt as the global function "tl.transform.MergeIfStmt" so it
// can be looked up through the FFI.
TVM_FFI_STATIC_INIT_BLOCK({
  tvm::ffi::reflection::GlobalDef().def("tl.transform.MergeIfStmt",
                                        MergeIfStmt);
});

} // namespace tl
} // namespace tvm