merge_if_stmt.cc 2.92 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/*!
 * \file if_stmt_binding.cc
 * \brief Merge the If Stmt in SeqStmt
 */

#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

class MergeIfStmtRewriter : public StmtExprMutator {
public:
  static PrimFunc Substitute(PrimFunc &f) {
    auto rewriter = MergeIfStmtRewriter();
    f.CopyOnWrite()->body = rewriter(f->body);
    return f;
  }

private:
  MergeIfStmtRewriter() = default;

  Stmt VisitStmt_(const SeqStmtNode *op) final {
    Array<Stmt> new_seq;

    PrimExpr current_condition;
    Array<Stmt> current_if_bodies;

    for (const Stmt &stmt : op->seq) {
      Stmt new_stmt = this->VisitStmt(stmt);
      if (const IfThenElseNode *if_node = new_stmt.as<IfThenElseNode>()) {
        if (!if_node->else_case.defined()) {
          if (current_condition.defined() &&
              StructuralEqual()(current_condition, if_node->condition)) {
            current_if_bodies.push_back(if_node->then_case);
            continue;
          } else {
            if (!current_if_bodies.empty()) {
              new_seq.push_back(IfThenElse(current_condition,
                                           current_if_bodies.size() == 1
                                               ? current_if_bodies[0]
                                               : SeqStmt(current_if_bodies),
                                           Stmt()));
              current_if_bodies.clear();
            }

            current_condition = if_node->condition;
            current_if_bodies.push_back(if_node->then_case);
            continue;
          }
        }
      }

      if (!current_if_bodies.empty()) {
        new_seq.push_back(IfThenElse(current_condition,
                                     current_if_bodies.size() == 1
                                         ? current_if_bodies[0]
                                         : SeqStmt(current_if_bodies),
                                     Stmt()));
        current_condition = PrimExpr();
        current_if_bodies.clear();
      }

      new_seq.push_back(new_stmt);
    }

    if (!current_if_bodies.empty()) {
      new_seq.push_back(IfThenElse(current_condition,
                                   current_if_bodies.size() == 1
                                       ? current_if_bodies[0]
                                       : SeqStmt(current_if_bodies),
                                   Stmt()));
    }

    return new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq);
  }
};

using namespace tir::transform;
tvm::transform::Pass MergeIfStmt() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return MergeIfStmtRewriter::Substitute(f);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {});
}

TVM_REGISTER_GLOBAL("tl.transform.MergeIfStmt").set_body_typed(MergeIfStmt);

} // namespace tl
} // namespace tvm