inject_assumes.cc 5.31 KB
Newer Older
1
2
3
4
5
/*!
 * \file inject_assumes.cc
 * \brief Inject assumes on buffer's shape boundary check. Also convert
 * existing assumes to AttrNodes.
 */
6

7
#include "common/assume.h"
8
9
10
11
12
13
14
#include "tvm/arith/analyzer.h"
#include "tvm/ffi/optional.h"
#include "tvm/ir/expr.h"
#include "tvm/ir/transform.h"
#include "tvm/node/structural_hash.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/expr.h"
15
#include "tvm/tir/op.h"
16
17
18
#include "tvm/tir/stmt.h"
#include "tvm/tir/stmt_functor.h"
#include "tvm/tir/transform.h"
19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <sstream>

namespace tvm::tl {
using namespace tir;

class AssumeInjector : public tvm::tir::StmtExprMutator {
  using Base = tvm::tir::StmtExprMutator;

public:
  AssumeInjector(PrimFunc f) : f(f) {}
  static PrimFunc Substitute(PrimFunc f) {
    auto injector = AssumeInjector(f);
    f.CopyOnWrite()->body = injector(f->body);
    return f;
  }

private:
37
  struct AssumeCreator {
38
39
40
41
    struct Item {
      PrimExpr expr;
      std::vector<Buffer> buffers;
    };
42

43
44
    tvm::StructuralHash sh;
    tvm::StructuralEqual se;
45
46
    // grouped by expr, since the amount of variadic shape symbols is usually
    // much smaller than buffer
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
    std::vector<Item> items;
    // hash => index in items
    std::unordered_map<size_t, std::vector<size_t>> buckets;
    void addExpr(PrimExpr e, Buffer buffer) {
      size_t h = sh(e);
      auto &bucket = buckets[h];
      auto it = std::find_if(bucket.begin(), bucket.end(), [&](size_t y) {
        return se(e, items[y].expr, true);
      });
      if (it == bucket.end()) {
        auto index = items.size();
        items.push_back({e, {buffer}});
        bucket.push_back(index);
      } else {
        items[*it].buffers.push_back(buffer);
      }
    }
64

65
66
67
68
69
70
71
    void addBuffer(Buffer buf) {
      for (auto shape : buf->shape) {
        if (shape->IsInstance<IntImmNode>())
          continue;
        addExpr(shape, buf);
      }
    }
72

73
74
75
    Stmt build(Stmt body) {
      auto analyzer = arith::Analyzer{};
      for (const auto &e : items) {
76
77
        auto simplified =
            analyzer.Simplify(GT(e.expr, make_zero(e.expr->dtype)));
78
79
80
81
82
83
84
85
86
87
88
89
90
91
        std::stringstream ss;
        ss << "Buffer shape should be greater than 0: shape `" << e.expr
           << "` from buffer ";
        for (size_t i = 0; i < e.buffers.size(); i++) {
          if (i)
            ss << ", ";
          ss << "`" << e.buffers[i]->name << "`";
        }
        body = AttrStmt(simplified, tir::attr::tilelang_assume,
                        StringImm(ss.str()), body);
      }
      return body;
    }
  };
92

93
94
  Stmt VisitStmt_(const DeclBufferNode *op) final {
    auto body = VisitStmt(op->body);
95
    AssumeCreator c;
96
97
98
    c.addBuffer(op->buffer);
    return DeclBuffer(op->buffer, c.build(body), op->span);
  }
99

100
101
102
103
104
105
  Stmt VisitStmt_(const SeqStmtNode *op) final {
    struct AssumeGroup {
      std::optional<PrimExpr> e;
      std::vector<Stmt> stmts;
    };
    std::vector<AssumeGroup> groups = {AssumeGroup{std::nullopt, {}}};
106
    for (size_t i = 0; i < op->seq.size(); i++) {
107
      auto stmt = VisitStmt(op->seq[i]);
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
      // Convert assume in evaluate form to assume attribute.
      // By default, we have the following IR:
      //    T.assume(cond1)
      //    Stmt1
      //    Stmt2
      //    T.assume(cond2)
      // This SeqStmt will be converted to:
      //    With(attr::tilelang_assume, cond1) {
      //      Stmt1
      //      Stmt2
      //    }
      //    With(attr::tilelang_assume, cond2) {
      //      ...
      //    }
      if (auto e = GetAssumeExprInEvaluateForm(stmt)) {
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
        groups.push_back(AssumeGroup{*e, {}});
      } else {
        groups.back().stmts.push_back(stmt);
      }
    }
    for (size_t i = groups.size(); i--;) {
      auto &g = groups[i];
      if (g.e) {
        Stmt body = g.stmts.size() == 1 ? g.stmts[0] : SeqStmt(g.stmts);
        std::stringstream ss;
        ss << "Assume: " << *(g.e);
        AttrStmt attr = AttrStmt(*g.e, tir::attr::tilelang_assume,
                                 StringImm(ss.str()), body);
        groups[i - 1].stmts.push_back(attr);
      } else {
        ICHECK(i == 0) << "only the first group can have no assume";
      }
    }
    return groups[0].stmts.size() == 1 ? groups[0].stmts[0]
                                       : SeqStmt(groups[0].stmts);
    // return SeqStmt(groups[0].stmts);
  }
145

146
147
  Stmt VisitStmt_(const BlockNode *op) final {
    auto body = VisitStmt(op->body);
148
149
150
151
152
    AssumeCreator c;

    // NOTE(chaofan): We only inject assumes from function arguments in the
    // root block.
    if (op->name_hint == "root") {
153
154
155
156
157
158
159
160
161
162
      for (auto item : f->buffer_map) {
        c.addBuffer(item.second);
      }
    }
    for (auto item : op->alloc_buffers) {
      c.addBuffer(item);
    }
    for (auto item : op->match_buffers) {
      c.addBuffer(item->buffer);
    }
163

164
165
166
167
    return Block(op->iter_vars, op->reads, op->writes, op->name_hint,
                 c.build(body), op->init, op->alloc_buffers, op->match_buffers,
                 op->annotations, op->span);
  }
168

169
170
171
172
173
174
175
176
177
178
179
180
  PrimFunc f;
};

using namespace tir::transform;

/*!
 * \brief Create the `tl.InjectAssumes` pass.
 *
 * The pass applies AssumeInjector::Substitute to each PrimFunc. The IRModule
 * and PassContext arguments are ignored.
 */
tvm::transform::Pass InjectAssumes() {
  auto inject = [](PrimFunc func, IRModule, PassContext) -> PrimFunc {
    return AssumeInjector::Substitute(func);
  };
  constexpr int kOptLevel = 0;
  return CreatePrimFuncPass(inject, kOptLevel, "tl.InjectAssumes", {});
}

181
TVM_FFI_STATIC_INIT_BLOCK() {
182
183
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.InjectAssumes", InjectAssumes);
184
}
185
186

} // namespace tvm::tl