simplify.cc 15.8 KB
Newer Older
1
2
3
4
5
/*!
 * \file simplify.cc
 * \brief Simplify the statements and expressions of a TL PrimFunc and,
 *        optionally, remove buffer parameters that the body never uses.
 */

#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <optional>
#include <unordered_set>
#include <utility>

#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/control_flow_graph.h"
#include "tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace tl {

using namespace tir;
using namespace arith;

/*!
 * \brief Options controlling the tl.Simplify pass.
 *
 * Every option defaults to false. Three of them map directly onto
 * RewriteSimplifier extensions (see GetEnabledExtensions). The two
 * propagate_knowns_* options are consumed by the pass itself, which builds a
 * ControlFlowGraph when either is set.
 */
struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
  bool transitively_prove_inequalities;
  bool propagate_knowns_to_prove_conditional;
  bool propagate_knowns_to_simplify_expressions;
  bool convert_boolean_to_and_of_ors;
  bool apply_constraints_to_boolean_branches;

  TVM_DECLARE_ATTRS(SimplifyConfigNode, "tl.transform.SimplifyConfig") {
    TVM_ATTR_FIELD(transitively_prove_inequalities)
        .describe("If true, simplify conditionals with transitive combinations "
                  "of scoped constraints")
        .set_default(false);

    TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional)
        .describe("If true, known buffer values are propagated and used to "
                  "statically prove conditionals")
        .set_default(false);

    TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions)
        .describe("If true, known buffer values are propagated and used to "
                  "replace BufferLoad wherever "
                  "possible")
        .set_default(false);

    TVM_ATTR_FIELD(convert_boolean_to_and_of_ors)
        .describe("If true, simplify conditionals into an AND of ORs")
        .set_default(false);

    TVM_ATTR_FIELD(apply_constraints_to_boolean_branches)
        .describe("If true, simplify each branch of AND/OR "
                  "under a constraints provided by the other branch")
        .set_default(false);
  }

  /*!
   * \brief Translate the boolean options into a RewriteSimplifier bitmask.
   * \return The OR of every extension whose option is enabled, or kNone.
   */
  RewriteSimplifier::Extension GetEnabledExtensions() const {
    int mask = RewriteSimplifier::kNone;
    if (transitively_prove_inequalities) {
      mask |= RewriteSimplifier::kTransitivelyProveInequalities;
    }
    if (convert_boolean_to_and_of_ors) {
      mask |= RewriteSimplifier::kConvertBooleanToAndOfOrs;
    }
    if (apply_constraints_to_boolean_branches) {
      mask |= RewriteSimplifier::kApplyConstraintsToBooleanBranches;
    }
    return RewriteSimplifier::Extension(mask);
  }
};

74
75
std::unordered_set<const BufferNode *>
CollectUsedBuffers(const PrimFunc &func) {
76
77
78
79
80
81
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

    Visitor(PrimFunc func) : func(func) {}

82
83
84
85
86
87
    void VisitExpr_(const CallNode *op) override {
      for (const auto &arg : op->args) {
        for (const auto &it : func->buffer_map) {
          if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
            used_in_buffer_def_.insert(it.second.get());
          }
88
        }
89
90
      }
      StmtExprVisitor::VisitExpr_(op);
91
    }
92
    void VisitExpr_(const BufferLoadNode *op) override {
93
94
95
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
96
    void VisitStmt_(const BufferStoreNode *op) override {
97
98
99
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }
100
101
102
103
104
105
    void VisitStmt_(const BlockNode *op) override {
      for (const auto &buffer : op->alloc_buffers) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
106
        }
107
108
109
110
111
112
      }
      for (const auto &buffer : op->reads) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
113
        }
114
115
116
117
118
119
      }
      for (const auto &buffer : op->writes) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
120
        }
121
122
      }
      StmtExprVisitor::VisitStmt_(op);
123
124
    }

125
    void VisitBuffer(const Buffer &buf) {
126
127
128
      // Collect buffers that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
129
      for (const auto &dim : buf->shape) {
130
131
        usage(dim);
      }
132
      for (const auto &dim : buf->strides) {
133
134
135
136
        usage(dim);
      }
      usage(buf->elem_offset);

137
      for (const auto &buffer : usage.buffer_use_count_) {
138
        if (buffer.second >= 1) {
139
          used_in_buffer_def_.insert(buffer.first);
140
141
        }
      }
142
      for (const auto &buffer : usage.undefined_buffers_) {
143
144
145
146
        used_in_buffer_def_.insert(buffer.get());
      }
    }
    PrimFunc func;
147
    std::unordered_set<const BufferNode *> used_in_buffer_def_;
148
149
150
151
152
153
154
  };

  Visitor visitor(func);
  visitor(func->body);
  return visitor.used_in_buffer_def_;
}

155
156
157
158
159
/* \brief Utility function to collect vars that should be retained. Used in
 * Letstmt Only
 */
std::unordered_set<const VarNode *>
CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
160
161
162
163
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

164
    void VisitExpr_(const BufferLoadNode *op) override {
165
166
167
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
168
    void VisitStmt_(const BufferStoreNode *op) override {
169
170
171
172
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }

173
    void VisitBuffer(const Buffer &buf) {
174
175
176
      // Collect variables that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
177
      for (const auto &dim : buf->shape) {
178
179
        usage(dim);
      }
180
      for (const auto &dim : buf->strides) {
181
182
183
184
185
        usage(dim);
      }
      usage(buf->elem_offset);

      // Track for use in LetStmtNode mutator
186
      for (const auto &var : usage.undefined_) {
187
188
189
        used_in_buffer_def_.insert(var.get());
      }
    }
190
    std::unordered_set<const VarNode *> used_in_buffer_def_;
191
192
193
194
195
196
197
198
  };

  Visitor visitor;
  visitor(stmt);
  return visitor.used_in_buffer_def_;
}

class SimplifyConfig : public Attrs {
199
200
201
public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
                                            SimplifyConfigNode);
202
203
204
205
206
207
};

// Register the config node type with TVM's object system, and expose it as
// the "tl.Simplify" PassContext option read by the Simplify pass below.
TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);

class StmtSimplifier : public IRMutatorWithAnalyzer {
208
209
public:
  static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
210
211
                        Optional<SimplifyConfig> config_opt = NullOpt,
                        bool simplify_arguments = false) {
212
    auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
213
214
    analyzer->rewrite_simplify.SetEnabledExtensions(
        config->GetEnabledExtensions());
215
216
217
218
219
220
221

    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
    if (config->propagate_knowns_to_prove_conditional ||
        config->propagate_knowns_to_simplify_expressions) {
      touch_pattern = ControlFlowGraph(func->body);
    }

222
    std::unordered_set<const VarNode *> used_in_buffer_def =
223
224
225
226
227
228
229
230
231
232
233
234
235
        CollectVarsUsedInBufferDefinition(func->body);
    StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
                              std::move(used_in_buffer_def));
    simplifier.MarkBufferMapShapes(func);
    func.CopyOnWrite()->body = simplifier(func->body);

    // Begin to remove useless var and buffer
    // First get used buffers
    simplifier.used_buffers_ = CollectUsedBuffers(func);
    bool param_updated = false;
    Array<Var> new_params;
    Map<Var, Buffer> new_buffer_map;
    // Check whether each buffer is used
236
237
238
239
240
241
242
243
    for (const auto &var : func->params) {
      if (func->buffer_map.find(var) != func->buffer_map.end()) {
        if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
            simplifier.used_buffers_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
        } else {
          param_updated = true;
244
        }
245
      }
246
    }
247
248

    if (simplify_arguments && param_updated) {
249
250
      return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
                      new_buffer_map, func->attrs, func->span);
251
    } else {
252
      return func;
253
254
255
    }
  }

256
257
258
259
260
261
262
263
private:
  explicit StmtSimplifier(
      Analyzer *analyzer, SimplifyConfig config,
      std::optional<ControlFlowGraph> touch_pattern,
      std::unordered_set<const VarNode *> used_in_buffer_def)
      : IRMutatorWithAnalyzer(analyzer), config_(config),
        touch_pattern_(touch_pattern), used_in_buffer_def_(used_in_buffer_def) {
  }
264
265
266
267
268
269

  using Parent = IRMutatorWithAnalyzer;
  using Parent::VisitExpr_;
  using Parent::VisitStmt;
  using Parent::VisitStmt_;

270
  PrimExpr VisitExpr(const PrimExpr &expr) final {
271
    if (config_->propagate_knowns_to_simplify_expressions) {
272
273
      return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(),
                                               analyzer_);
274
275
276
277
278
279
280
    } else {
      return analyzer_->Simplify(expr);
    }
  }

  Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }

281
  Stmt VisitStmt(const Stmt &stmt) override {
282
283
284
285
286
287
288
    Optional<Stmt> cache = this->current_stmt_;
    this->current_stmt_ = stmt;
    Stmt output = Parent::VisitStmt(stmt);
    this->current_stmt_ = std::move(cache);
    return output;
  }

289
  Stmt VisitStmt_(const ForNode *op) final {
290
291
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
    With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
292
293
    With<ConstraintContext> ctx2(analyzer_,
                                 op->loop_var < op->min + op->extent);
294
295
296
    return Parent::VisitStmt_(op);
  }

297
298
299
300
301
  bool CanInlineLetStmt(const LetStmtNode *op) {
    if (is_const_number(op->value))
      return true;
    if (op->value.as<VarNode>())
      return true;
302
    // Won't face the deep expression explosion problem as in Let expression.
303
304
305
306
    // attempt to inline as much as possible if the value integer type(can be
    // index).
    if (!op->value.dtype().is_int())
      return false;
307
308
309
    return SideEffect(op->value) <= CallEffectKind::kPure;
  }

310
  Stmt VisitStmt_(const LetStmtNode *op) override {
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
    PrimExpr value = this->VisitExpr(op->value);
    bool can_inline = CanInlineLetStmt(op);
    if (can_inline) {
      // It is usually fine to discard the let binding because the
      // call to simplify will always inline the var.
      //
      // The exception is when the variable is used in a Buffer's
      // definition, as these are not updated by the simplification.
      // After DeclBuffer is required prior to use of a buffer,
      // simplifying can update the buffer definition as well.  The
      // buffer can only be updated at its point of definition,
      // because the points of use may occur within contexts that
      // allow for additional simplifications (e.g. a buffer of shape
      // [i,j] whose first use occurs within "if i==1" should not have
      // its shape simplified to [1,j]).
      analyzer_->Bind(op->var, value);
    } else if (SideEffect(op->value) <= CallEffectKind::kPure) {
      // Even if we aren't replacing all occurrences, they may be
      // necessary for proving conditional statements.
      non_inlined_bindings_.Set(op->var, value);
    }
    Stmt body = this->VisitStmt(op->body);

    // TODO(Lunderberg): Update the Buffer object as part of
    // DeclBuffer updates, which will first require
    // https://github.com/apache/tvm/pull/14778.
    bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

    if (can_inline && !used_in_buffer_def) {
      return body;
    } else if (value.same_as(op->value) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = this->CopyOnWrite(op);
      n->value = std::move(value);
      n->body = std::move(body);
      return Stmt(n);
    }
  }

351
  Stmt VisitStmt_(const IfThenElseNode *op) override {
352
353
354
355
356
357
358
359
360
361
362
363
364
    if (Optional<Bool> cond = ProveCondition(op->condition)) {
      if (cond.value()->value) {
        return this->VisitStmt(op->then_case);
      } else if (op->else_case) {
        return this->VisitStmt(op->else_case.value());
      } else {
        return Evaluate(0);
      }
    } else {
      return Parent::VisitStmt_(op);
    }
  }

365
  PrimExpr VisitExpr_(const CallNode *op) override {
366
367
368
369
370
371
372
373
374
375
376
377
    if (op->op.same_as(builtin::if_then_else())) {
      if (Optional<Bool> cond = ProveCondition(op->args[0])) {
        if (cond.value()->value) {
          return this->VisitExpr(op->args[1]);
        } else {
          return this->VisitExpr(op->args[2]);
        }
      }
    }
    return Parent::VisitExpr_(op);
  }

378
  PrimExpr VisitExpr_(const VarNode *op) override {
379
380
381
382
    used_vars_.insert(op);
    return Parent::VisitExpr_(op);
  }

383
  PrimExpr VisitExpr_(const BufferLoadNode *op) override {
384
385
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
386
      used_buffers_.insert(buffer);
387
388
389
390
391
    }
    return Parent::VisitExpr_(op);
  }

  // eliminate useless stores
392
  Stmt VisitStmt_(const BufferStoreNode *op) override {
393
    BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
394
    if (const BufferLoadNode *load = store->value.as<BufferLoadNode>()) {
395
396
      if (load->buffer->data.same_as(store->buffer->data) &&
          ArrayDeepEqual(load->indices, store->indices) &&
397
398
          tir::ExprDeepEqual()(load->buffer->elem_offset,
                               store->buffer->elem_offset) &&
399
400
401
402
403
404
405
          ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
          ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
        return Evaluate(0);
      }
    }
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
406
      used_buffers_.insert(buffer);
407
408
409
410
    }
    return std::move(store);
  }

411
412
private:
  bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
    if (lhs.size() != rhs.size()) {
      return false;
    }
    for (size_t i = 0; i < lhs.size(); i++) {
      if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) {
        return false;
      }
    }
    return true;
  }

  /* \brief Internal utility for checking conditionals
   *
   * Uses more aggressive optimization, such as performing additional
   * inlining and tracking known buffer values.
   */
  Optional<Bool> ProveCondition(PrimExpr condition) const {
    condition = Substitute(condition, non_inlined_bindings_);
    if (config_->propagate_knowns_to_prove_conditional) {
      ICHECK(touch_pattern_.has_value());
433
434
      condition = touch_pattern_->SimplifyInContext(
          condition, current_stmt_.value(), analyzer_);
435
436
437
    } else {
      condition = analyzer_->Simplify(condition);
    }
438
    if (const int64_t *as_int = as_const_int(condition)) {
439
440
      return Bool(*as_int);
    } else {
441
442
443
444
445
446
      // May have symbolic, need kSymbolicBound level prover.
      if (analyzer_->CanProve(condition) ||
          analyzer_->CanProve(condition,
                              arith::ProofStrength::kSymbolicBound)) {
        return Bool(true);
      }
447
448
449
450
451
452
453
454
455
      return NullOpt;
    }
  }

  SimplifyConfig config_;
  std::optional<ControlFlowGraph> touch_pattern_;

  Map<Var, PrimExpr> non_inlined_bindings_;
  Optional<Stmt> current_stmt_{NullOpt};
456
457
458
  std::unordered_set<const VarNode *> used_in_buffer_def_;
  std::unordered_set<const VarNode *> used_vars_;
  std::unordered_set<const BufferNode *> used_buffers_;
459
460
461
462
};

using namespace tir::transform;

463
tvm::transform::Pass Simplify(bool simplify_arguments = true) {
464
465
466
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    arith::Analyzer analyzer;
    auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
467
    return StmtSimplifier::Apply(f, &analyzer, cfg, simplify_arguments);
468
469
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
470
471
472
473
474
475
}

// Expose the pass factory through the FFI as "tl.transform.Simplify". Note
// that the C++ default for simplify_arguments is not visible over the FFI;
// callers must pass the flag explicitly.
TVM_REGISTER_GLOBAL("tl.transform.Simplify").set_body_typed(Simplify);

} // namespace tl
} // namespace tvm