"test/vscode:/vscode.git/clone" did not exist on "e2b588c03a12152bb25567c8fcab78cbf1971bcd"
merge_if_stmt.cc 3.91 KB
Newer Older
1
2
3
4
5
/*!
 * \file merge_if_stmt.cc
 * \brief Merge adjacent IfThenElse statements (without else branch) that
 *        share the same condition inside a SeqStmt.
 */

6
7
#include "merge_if_stmt.h"

8
#include <tvm/ffi/reflection/registry.h>
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

class MergeIfStmtRewriter : public StmtExprMutator {
public:
  static PrimFunc Substitute(PrimFunc &f) {
25
    f.CopyOnWrite()->body = MergeIfStmtRewriter::Apply(f->body);
26
27
28
    return f;
  }

29
30
31
32
33
  static Stmt Apply(Stmt stmt) {
    auto rewriter = MergeIfStmtRewriter();
    return rewriter(stmt);
  }

34
35
36
private:
  MergeIfStmtRewriter() = default;

37
38
39
40
41
42
43
44
45
46
  void FlattenAppend(const Stmt &s, Array<Stmt> *out) {
    if (const auto *seq = s.as<SeqStmtNode>()) {
      for (const Stmt &e : seq->seq) {
        FlattenAppend(e, out);
      }
    } else {
      out->push_back(s);
    }
  }

47
  Stmt VisitStmt_(const SeqStmtNode *op) final {
48
49
50
51
52
53
54
55
    // First, recursively flatten nested SeqStmt so that
    //   SeqStmt{ if, SeqStmt{ if, SeqStmt{ if } } }
    // becomes a single-level sequence of [if, if, if].
    Array<Stmt> flat_seq;
    for (const Stmt &stmt : op->seq) {
      Stmt new_stmt = this->VisitStmt(stmt);
      FlattenAppend(new_stmt, &flat_seq);
    }
56

57
58
59
    // Then, merge consecutive IfThenElse (without else) that share the same
    // condition.
    Array<Stmt> new_seq;
60
61
62
    PrimExpr current_condition;
    Array<Stmt> current_if_bodies;

63
64
    for (const Stmt &stmt : flat_seq) {
      if (const auto *if_node = stmt.as<IfThenElseNode>()) {
65
66
        if (!if_node->else_case.defined()) {
          if (current_condition.defined() &&
67
              ExprDeepEqual()(current_condition, if_node->condition)) {
68
69
70
71
            current_if_bodies.push_back(if_node->then_case);
            continue;
          } else {
            if (!current_if_bodies.empty()) {
72
73
74
75
76
77
78
              auto if_stmt =
                  IfThenElse(current_condition,
                             current_if_bodies.size() == 1
                                 ? current_if_bodies[0]
                                 : this->VisitStmt(SeqStmt(current_if_bodies)),
                             Stmt());
              new_seq.push_back(if_stmt);
79
80
81
82
83
84
85
86
87
88
89
              current_if_bodies.clear();
            }

            current_condition = if_node->condition;
            current_if_bodies.push_back(if_node->then_case);
            continue;
          }
        }
      }

      if (!current_if_bodies.empty()) {
90
91
92
93
94
95
96
        auto if_stmt =
            IfThenElse(current_condition,
                       current_if_bodies.size() == 1
                           ? current_if_bodies[0]
                           : this->VisitStmt(SeqStmt(current_if_bodies)),
                       Stmt());
        new_seq.push_back(if_stmt);
97
98
99
100
        current_condition = PrimExpr();
        current_if_bodies.clear();
      }

101
      new_seq.push_back(stmt);
102
103
104
    }

    if (!current_if_bodies.empty()) {
105
106
107
108
109
110
111
      auto if_stmt =
          IfThenElse(current_condition,
                     current_if_bodies.size() == 1
                         ? current_if_bodies[0]
                         : this->VisitStmt(SeqStmt(current_if_bodies)),
                     Stmt());
      new_seq.push_back(if_stmt);
112
113
114
115
116
117
    }

    return new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq);
  }
};

118
119
120
121
122
123
PrimFunc MergeIfStmtSubstitute(PrimFunc &f) {
  return MergeIfStmtRewriter::Substitute(f);
}

/*!
 * \brief Merge adjacent same-condition IfThenElse statements in \p stmt.
 * \param stmt Statement tree to rewrite.
 * \return The rewritten statement.
 */
Stmt ApplyMergeIfStmt(Stmt stmt) {
  Stmt merged = MergeIfStmtRewriter::Apply(stmt);
  return merged;
}

124
125
using namespace tir::transform;
tvm::transform::Pass MergeIfStmt() {
126
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
127
128
129
130
131
    return MergeIfStmtRewriter::Substitute(f);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {});
}

132
TVM_FFI_STATIC_INIT_BLOCK() {
133
134
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.MergeIfStmt", MergeIfStmt);
135
}
136
137
138

} // namespace tl
} // namespace tvm