/*!
 * \file simplify.cc
 * \brief Simplify the body of a TL PrimFunc and remove its unused buffer
 *        parameters.
 */

#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <optional>
#include <unordered_set>
#include <utility>

#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/control_flow_graph.h"
#include "tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace tl {

// Make the TIR and arithmetic-analysis names used throughout this file
// available without qualification.
using namespace tir;
using namespace arith;

/*!
 * \brief Attributes selecting the optional behaviours of the tl.Simplify pass.
 *
 * Every field is declared with a default of false, so a default-constructed
 * configuration enables none of the optional rewrite-simplifier extensions
 * and no propagation of known buffer values.
 */
struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
  bool transitively_prove_inequalities;
  bool propagate_knowns_to_prove_conditional;
  bool propagate_knowns_to_simplify_expressions;
  bool convert_boolean_to_and_of_ors;
  bool apply_constraints_to_boolean_branches;

  TVM_DECLARE_ATTRS(SimplifyConfigNode, "tl.transform.SimplifyConfig") {
    TVM_ATTR_FIELD(transitively_prove_inequalities)
        .describe("If true, simplify conditionals with transitive combinations "
                  "of scoped constraints")
        .set_default(false);

    TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional)
        .describe("If true, known buffer values are propagated and used to "
                  "statically prove conditionals")
        .set_default(false);

    TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions)
        .describe("If true, known buffer values are propagated and used to "
                  "replace BufferLoad wherever "
                  "possible")
        .set_default(false);

    TVM_ATTR_FIELD(convert_boolean_to_and_of_ors)
        .describe("If true, simplify conditionals into an AND of ORs")
        .set_default(false);

    TVM_ATTR_FIELD(apply_constraints_to_boolean_branches)
        .describe("If true, simplify each branch of AND/OR "
                  "under a constraints provided by the other branch")
        .set_default(false);
  }

  /*!
   * \brief Translate the boolean fields into RewriteSimplifier extension bits.
   *
   * Only the three fields that have a matching RewriteSimplifier::Extension
   * value contribute; the two propagate_knowns_* fields are consumed directly
   * by StmtSimplifier and do not appear in the result.
   *
   * \return The bitwise OR of all enabled extensions (kNone if none are set).
   */
  RewriteSimplifier::Extension GetEnabledExtensions() const {
    RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
    if (transitively_prove_inequalities) {
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kTransitivelyProveInequalities);
    }
    if (convert_boolean_to_and_of_ors) {
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kConvertBooleanToAndOfOrs);
    }
    if (apply_constraints_to_boolean_branches) {
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kApplyConstraintsToBooleanBranches);
    }
    return flags;
  }
};

std::unordered_set<const BufferNode *>
CollectUsedBuffers(const PrimFunc &func) {
76
77
78
79
80
81
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

    Visitor(PrimFunc func) : func(func) {}

82
83
84
85
86
87
    void VisitExpr_(const CallNode *op) override {
      for (const auto &arg : op->args) {
        for (const auto &it : func->buffer_map) {
          if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
            used_in_buffer_def_.insert(it.second.get());
          }
88
        }
89
90
      }
      StmtExprVisitor::VisitExpr_(op);
91
    }
92
    void VisitExpr_(const BufferLoadNode *op) override {
93
94
95
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
96
    void VisitStmt_(const BufferStoreNode *op) override {
97
98
99
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }
100
101
102
103
104
105
    void VisitStmt_(const BlockNode *op) override {
      for (const auto &buffer : op->alloc_buffers) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
106
        }
107
108
109
110
111
112
      }
      for (const auto &buffer : op->reads) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
113
        }
114
115
116
117
118
119
      }
      for (const auto &buffer : op->writes) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
120
        }
121
122
      }
      StmtExprVisitor::VisitStmt_(op);
123
124
    }

125
    void VisitBuffer(const Buffer &buf) {
126
127
128
      // Collect buffers that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
129
      for (const auto &dim : buf->shape) {
130
131
        usage(dim);
      }
132
      for (const auto &dim : buf->strides) {
133
134
135
136
        usage(dim);
      }
      usage(buf->elem_offset);

137
      for (const auto &buffer : usage.buffer_use_count_) {
138
        if (buffer.second >= 1) {
139
          used_in_buffer_def_.insert(buffer.first);
140
141
        }
      }
142
      for (const auto &buffer : usage.undefined_buffers_) {
143
144
145
146
        used_in_buffer_def_.insert(buffer.get());
      }
    }
    PrimFunc func;
147
    std::unordered_set<const BufferNode *> used_in_buffer_def_;
148
149
150
151
152
153
154
  };

  Visitor visitor(func);
  visitor(func->body);
  return visitor.used_in_buffer_def_;
}

155
156
157
158
159
/* \brief Utility function to collect vars that should be retained. Used in
 * Letstmt Only
 */
std::unordered_set<const VarNode *>
CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
160
161
162
163
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

164
    void VisitExpr_(const BufferLoadNode *op) override {
165
166
167
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
168
    void VisitStmt_(const BufferStoreNode *op) override {
169
170
171
172
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }

173
    void VisitBuffer(const Buffer &buf) {
174
175
176
      // Collect variables that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
177
      for (const auto &dim : buf->shape) {
178
179
        usage(dim);
      }
180
      for (const auto &dim : buf->strides) {
181
182
183
184
185
        usage(dim);
      }
      usage(buf->elem_offset);

      // Track for use in LetStmtNode mutator
186
      for (const auto &var : usage.undefined_) {
187
188
189
        used_in_buffer_def_.insert(var.get());
      }
    }
190
    std::unordered_set<const VarNode *> used_in_buffer_def_;
191
192
193
194
195
196
197
198
  };

  Visitor visitor;
  visitor(stmt);
  return visitor.used_in_buffer_def_;
}

class SimplifyConfig : public Attrs {
199
200
201
public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
                                            SimplifyConfigNode);
202
203
204
205
206
207
};

// Register the config node type and declare "tl.Simplify" as a pass-config
// option of type SimplifyConfig; Simplify() reads it back via GetConfig.
TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);

class StmtSimplifier : public IRMutatorWithAnalyzer {
208
209
public:
  static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
210
211
                        Optional<SimplifyConfig> config_opt = NullOpt) {
    auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
212
213
    analyzer->rewrite_simplify.SetEnabledExtensions(
        config->GetEnabledExtensions());
214
215
216
217
218
219
220

    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
    if (config->propagate_knowns_to_prove_conditional ||
        config->propagate_knowns_to_simplify_expressions) {
      touch_pattern = ControlFlowGraph(func->body);
    }

221
    std::unordered_set<const VarNode *> used_in_buffer_def =
222
223
224
225
226
227
228
229
230
231
232
233
234
        CollectVarsUsedInBufferDefinition(func->body);
    StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
                              std::move(used_in_buffer_def));
    simplifier.MarkBufferMapShapes(func);
    func.CopyOnWrite()->body = simplifier(func->body);

    // Begin to remove useless var and buffer
    // First get used buffers
    simplifier.used_buffers_ = CollectUsedBuffers(func);
    bool param_updated = false;
    Array<Var> new_params;
    Map<Var, Buffer> new_buffer_map;
    // Check whether each buffer is used
235
236
237
238
239
240
241
242
    for (const auto &var : func->params) {
      if (func->buffer_map.find(var) != func->buffer_map.end()) {
        if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
            simplifier.used_buffers_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
        } else {
          param_updated = true;
243
        }
244
      }
245
246
247
    }
    // return func;
    if (param_updated) {
248
249
      return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
                      new_buffer_map, func->attrs, func->span);
250
    } else {
251
      return func;
252
253
254
    }
  }

255
256
257
258
259
260
261
262
private:
  explicit StmtSimplifier(
      Analyzer *analyzer, SimplifyConfig config,
      std::optional<ControlFlowGraph> touch_pattern,
      std::unordered_set<const VarNode *> used_in_buffer_def)
      : IRMutatorWithAnalyzer(analyzer), config_(config),
        touch_pattern_(touch_pattern), used_in_buffer_def_(used_in_buffer_def) {
  }
263
264
265
266
267
268

  using Parent = IRMutatorWithAnalyzer;
  using Parent::VisitExpr_;
  using Parent::VisitStmt;
  using Parent::VisitStmt_;

269
  PrimExpr VisitExpr(const PrimExpr &expr) final {
270
    if (config_->propagate_knowns_to_simplify_expressions) {
271
272
      return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(),
                                               analyzer_);
273
274
275
276
277
278
279
    } else {
      return analyzer_->Simplify(expr);
    }
  }

  Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }

280
  Stmt VisitStmt(const Stmt &stmt) override {
281
282
283
284
285
286
287
    Optional<Stmt> cache = this->current_stmt_;
    this->current_stmt_ = stmt;
    Stmt output = Parent::VisitStmt(stmt);
    this->current_stmt_ = std::move(cache);
    return output;
  }

288
  Stmt VisitStmt_(const ForNode *op) final {
289
290
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
    With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
291
292
    With<ConstraintContext> ctx2(analyzer_,
                                 op->loop_var < op->min + op->extent);
293
294
295
    return Parent::VisitStmt_(op);
  }

296
297
298
299
300
  bool CanInlineLetStmt(const LetStmtNode *op) {
    if (is_const_number(op->value))
      return true;
    if (op->value.as<VarNode>())
      return true;
301
    // Won't face the deep expression explosion problem as in Let expression.
302
303
304
305
    // attempt to inline as much as possible if the value integer type(can be
    // index).
    if (!op->value.dtype().is_int())
      return false;
306
307
308
    return SideEffect(op->value) <= CallEffectKind::kPure;
  }

309
  Stmt VisitStmt_(const LetStmtNode *op) override {
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
    PrimExpr value = this->VisitExpr(op->value);
    bool can_inline = CanInlineLetStmt(op);
    if (can_inline) {
      // It is usually fine to discard the let binding because the
      // call to simplify will always inline the var.
      //
      // The exception is when the variable is used in a Buffer's
      // definition, as these are not updated by the simplification.
      // After DeclBuffer is required prior to use of a buffer,
      // simplifying can update the buffer definition as well.  The
      // buffer can only be updated at its point of definition,
      // because the points of use may occur within contexts that
      // allow for additional simplifications (e.g. a buffer of shape
      // [i,j] whose first use occurs within "if i==1" should not have
      // its shape simplified to [1,j]).
      analyzer_->Bind(op->var, value);
    } else if (SideEffect(op->value) <= CallEffectKind::kPure) {
      // Even if we aren't replacing all occurrences, they may be
      // necessary for proving conditional statements.
      non_inlined_bindings_.Set(op->var, value);
    }
    Stmt body = this->VisitStmt(op->body);

    // TODO(Lunderberg): Update the Buffer object as part of
    // DeclBuffer updates, which will first require
    // https://github.com/apache/tvm/pull/14778.
    bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

    if (can_inline && !used_in_buffer_def) {
      return body;
    } else if (value.same_as(op->value) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = this->CopyOnWrite(op);
      n->value = std::move(value);
      n->body = std::move(body);
      return Stmt(n);
    }
  }

350
  Stmt VisitStmt_(const IfThenElseNode *op) override {
351
352
353
354
355
356
357
358
359
360
361
362
363
    if (Optional<Bool> cond = ProveCondition(op->condition)) {
      if (cond.value()->value) {
        return this->VisitStmt(op->then_case);
      } else if (op->else_case) {
        return this->VisitStmt(op->else_case.value());
      } else {
        return Evaluate(0);
      }
    } else {
      return Parent::VisitStmt_(op);
    }
  }

364
  PrimExpr VisitExpr_(const CallNode *op) override {
365
366
367
368
369
370
371
372
373
374
375
376
    if (op->op.same_as(builtin::if_then_else())) {
      if (Optional<Bool> cond = ProveCondition(op->args[0])) {
        if (cond.value()->value) {
          return this->VisitExpr(op->args[1]);
        } else {
          return this->VisitExpr(op->args[2]);
        }
      }
    }
    return Parent::VisitExpr_(op);
  }

377
  PrimExpr VisitExpr_(const VarNode *op) override {
378
379
380
381
    used_vars_.insert(op);
    return Parent::VisitExpr_(op);
  }

382
  PrimExpr VisitExpr_(const BufferLoadNode *op) override {
383
384
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
385
      used_buffers_.insert(buffer);
386
387
388
389
390
    }
    return Parent::VisitExpr_(op);
  }

  // eliminate useless stores
391
  Stmt VisitStmt_(const BufferStoreNode *op) override {
392
    BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
393
    if (const BufferLoadNode *load = store->value.as<BufferLoadNode>()) {
394
395
      if (load->buffer->data.same_as(store->buffer->data) &&
          ArrayDeepEqual(load->indices, store->indices) &&
396
397
          tir::ExprDeepEqual()(load->buffer->elem_offset,
                               store->buffer->elem_offset) &&
398
399
400
401
402
403
404
          ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
          ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
        return Evaluate(0);
      }
    }
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
405
      used_buffers_.insert(buffer);
406
407
408
409
    }
    return std::move(store);
  }

410
411
private:
  bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
    if (lhs.size() != rhs.size()) {
      return false;
    }
    for (size_t i = 0; i < lhs.size(); i++) {
      if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) {
        return false;
      }
    }
    return true;
  }

  /* \brief Internal utility for checking conditionals
   *
   * Uses more aggressive optimization, such as performing additional
   * inlining and tracking known buffer values.
   */
  Optional<Bool> ProveCondition(PrimExpr condition) const {
    condition = Substitute(condition, non_inlined_bindings_);
    if (config_->propagate_knowns_to_prove_conditional) {
      ICHECK(touch_pattern_.has_value());
432
433
      condition = touch_pattern_->SimplifyInContext(
          condition, current_stmt_.value(), analyzer_);
434
435
436
    } else {
      condition = analyzer_->Simplify(condition);
    }
437
    if (const int64_t *as_int = as_const_int(condition)) {
438
439
440
441
442
443
444
445
446
447
448
      return Bool(*as_int);
    } else {
      return NullOpt;
    }
  }

  SimplifyConfig config_;
  std::optional<ControlFlowGraph> touch_pattern_;

  Map<Var, PrimExpr> non_inlined_bindings_;
  Optional<Stmt> current_stmt_{NullOpt};
449
450
451
  std::unordered_set<const VarNode *> used_in_buffer_def_;
  std::unordered_set<const VarNode *> used_vars_;
  std::unordered_set<const BufferNode *> used_buffers_;
452
453
454
455
456
};

// Make the tir::transform names used by Simplify() below available
// without qualification.
using namespace tir::transform;

/*!
 * \brief Create the "tl.Simplify" PrimFunc pass.
 *
 * Each function is processed by StmtSimplifier::Apply with a fresh analyzer
 * and the "tl.Simplify" configuration read from the pass context.
 *
 * \return The pass, created at opt level 0 with no required passes.
 */
tvm::transform::Pass Simplify() {
  // The lambda needs no captured state.
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    arith::Analyzer analyzer;
    auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
    return StmtSimplifier::Apply(std::move(f), &analyzer, cfg);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
}

// Register the pass constructor in the global function registry under the
// name "tl.transform.Simplify".
TVM_REGISTER_GLOBAL("tl.transform.Simplify").set_body_typed(Simplify);

} // namespace tl
} // namespace tvm