/*!
 * \file simplify.cc
 * \brief Statement simplifier built on the arithmetic analyzer; it can also
 * remove unused buffer parameters of a TL PrimFunc.
 */

#include <optional>
#include <unordered_set>
#include <utility>

#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/control_flow_graph.h"
#include "tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace tl {

using namespace tir;
using namespace arith;

24
struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
25
26
27
28
29
30
  bool transitively_prove_inequalities;
  bool propagate_knowns_to_prove_conditional;
  bool propagate_knowns_to_simplify_expressions;
  bool convert_boolean_to_and_of_ors;
  bool apply_constraints_to_boolean_branches;

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<SimplifyConfigNode>()
        .def_ro("transitively_prove_inequalities",
                &SimplifyConfigNode::transitively_prove_inequalities,
                "If true, simplify conditionals with transitive combinations "
                "of scoped constraints",
                refl::DefaultValue(false))
        .def_ro("propagate_knowns_to_prove_conditional",
                &SimplifyConfigNode::propagate_knowns_to_prove_conditional,
                "If true, known buffer values are propagated and used to "
                "statically prove conditionals",
                refl::DefaultValue(false))
        .def_ro("propagate_knowns_to_simplify_expressions",
                &SimplifyConfigNode::propagate_knowns_to_simplify_expressions,
                "If true, known buffer values are propagated and used to "
                "replace BufferLoad wherever "
                "possible",
                refl::DefaultValue(false))
        .def_ro("convert_boolean_to_and_of_ors",
                &SimplifyConfigNode::convert_boolean_to_and_of_ors,
                "If true, simplify conditionals into an AND of ORs",
                refl::DefaultValue(false))
        .def_ro("apply_constraints_to_boolean_branches",
                &SimplifyConfigNode::apply_constraints_to_boolean_branches,
                "If true, simplify each branch of AND/OR under a constraints "
                "provided by the other "
                "branch",
                refl::DefaultValue(false));
60
  }
61
62
  static constexpr const char *_type_key = "tl.transform.SimplifyConfig";
  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode);
63

64
  RewriteSimplifier::Extension GetEnabledExtensions() const {
65
66
    RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
    if (transitively_prove_inequalities) {
67
68
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kTransitivelyProveInequalities);
69
70
    }
    if (convert_boolean_to_and_of_ors) {
71
72
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kConvertBooleanToAndOfOrs);
73
74
    }
    if (apply_constraints_to_boolean_branches) {
75
76
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kApplyConstraintsToBooleanBranches);
77
78
79
80
81
    }
    return flags;
  }
};

82
83
std::unordered_set<const BufferNode *>
CollectUsedBuffers(const PrimFunc &func) {
84
85
86
87
88
89
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

    Visitor(PrimFunc func) : func(func) {}

90
91
92
93
94
95
    void VisitExpr_(const CallNode *op) override {
      for (const auto &arg : op->args) {
        for (const auto &it : func->buffer_map) {
          if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
            used_in_buffer_def_.insert(it.second.get());
          }
96
        }
97
98
      }
      StmtExprVisitor::VisitExpr_(op);
99
    }
100
    void VisitExpr_(const BufferLoadNode *op) override {
101
102
103
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
104
    void VisitStmt_(const BufferStoreNode *op) override {
105
106
107
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }
108
109
110
111
112
113
    void VisitStmt_(const BlockNode *op) override {
      for (const auto &buffer : op->alloc_buffers) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
114
        }
115
116
117
118
119
120
      }
      for (const auto &buffer : op->reads) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
121
        }
122
123
124
125
126
127
      }
      for (const auto &buffer : op->writes) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
128
        }
129
130
      }
      StmtExprVisitor::VisitStmt_(op);
131
132
    }

133
    void VisitBuffer(const Buffer &buf) {
134
135
136
      // Collect buffers that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
137
      for (const auto &dim : buf->shape) {
138
139
        usage(dim);
      }
140
      for (const auto &dim : buf->strides) {
141
142
143
144
        usage(dim);
      }
      usage(buf->elem_offset);

145
      for (const auto &buffer : usage.buffer_use_count_) {
146
        if (buffer.second >= 1) {
147
          used_in_buffer_def_.insert(buffer.first);
148
149
        }
      }
150
      for (const auto &buffer : usage.undefined_buffers_) {
151
152
153
154
        used_in_buffer_def_.insert(buffer.get());
      }
    }
    PrimFunc func;
155
    std::unordered_set<const BufferNode *> used_in_buffer_def_;
156
157
158
159
160
161
162
  };

  Visitor visitor(func);
  visitor(func->body);
  return visitor.used_in_buffer_def_;
}

163
164
165
166
167
/* \brief Utility function to collect vars that should be retained. Used in
 * Letstmt Only
 */
std::unordered_set<const VarNode *>
CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
168
169
170
171
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

172
    void VisitExpr_(const BufferLoadNode *op) override {
173
174
175
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
176
    void VisitStmt_(const BufferStoreNode *op) override {
177
178
179
180
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }

181
    void VisitBuffer(const Buffer &buf) {
182
183
184
      // Collect variables that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
185
      for (const auto &dim : buf->shape) {
186
187
        usage(dim);
      }
188
      for (const auto &dim : buf->strides) {
189
190
191
192
193
        usage(dim);
      }
      usage(buf->elem_offset);

      // Track for use in LetStmtNode mutator
194
      for (const auto &var : usage.undefined_) {
195
196
197
        used_in_buffer_def_.insert(var.get());
      }
    }
198
    std::unordered_set<const VarNode *> used_in_buffer_def_;
199
200
201
202
203
204
205
206
  };

  Visitor visitor;
  visitor(stmt);
  return visitor.used_in_buffer_def_;
}

class SimplifyConfig : public Attrs {
207
208
209
public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
                                            SimplifyConfigNode);
210
};
211
TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); });
212
213
214
215
216

TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);

class StmtSimplifier : public IRMutatorWithAnalyzer {
217
218
public:
  static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
219
                        Optional<SimplifyConfig> config_opt = std::nullopt,
220
                        bool simplify_arguments = false) {
221
    auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
222
223
    analyzer->rewrite_simplify.SetEnabledExtensions(
        config->GetEnabledExtensions());
224
225
226
227
228
229
230

    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
    if (config->propagate_knowns_to_prove_conditional ||
        config->propagate_knowns_to_simplify_expressions) {
      touch_pattern = ControlFlowGraph(func->body);
    }

231
    std::unordered_set<const VarNode *> used_in_buffer_def =
232
233
234
235
236
237
238
239
240
        CollectVarsUsedInBufferDefinition(func->body);
    StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
                              std::move(used_in_buffer_def));
    simplifier.MarkBufferMapShapes(func);
    func.CopyOnWrite()->body = simplifier(func->body);

    // Begin to remove useless var and buffer
    // First get used buffers
    simplifier.used_buffers_ = CollectUsedBuffers(func);
241

242
243
244
245
    bool param_updated = false;
    Array<Var> new_params;
    Map<Var, Buffer> new_buffer_map;
    // Check whether each buffer is used
246
247
248
249
250
251
    for (const auto &var : func->params) {
      if (func->buffer_map.find(var) != func->buffer_map.end()) {
        if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
            simplifier.used_buffers_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
252
253
254
255
256
        } else if (simplifier.used_in_buffer_def_.find(
                       func->buffer_map[var]->data.get()) !=
                   simplifier.used_in_buffer_def_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
257
258
        } else {
          param_updated = true;
259
        }
260
      }
261
    }
262

263
    if (param_updated) {
264
265
      return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
                      new_buffer_map, func->attrs, func->span);
266
    } else {
267
      return func;
268
269
270
    }
  }

271
272
273
274
275
276
277
278
private:
  explicit StmtSimplifier(
      Analyzer *analyzer, SimplifyConfig config,
      std::optional<ControlFlowGraph> touch_pattern,
      std::unordered_set<const VarNode *> used_in_buffer_def)
      : IRMutatorWithAnalyzer(analyzer), config_(config),
        touch_pattern_(touch_pattern), used_in_buffer_def_(used_in_buffer_def) {
  }
279
280
281
282
283
284

  using Parent = IRMutatorWithAnalyzer;
  using Parent::VisitExpr_;
  using Parent::VisitStmt;
  using Parent::VisitStmt_;

285
  PrimExpr VisitExpr(const PrimExpr &expr) final {
286
    if (config_->propagate_knowns_to_simplify_expressions) {
287
288
      return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(),
                                               analyzer_);
289
290
291
292
293
294
295
    } else {
      return analyzer_->Simplify(expr);
    }
  }

  Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }

296
  Stmt VisitStmt(const Stmt &stmt) override {
297
298
299
300
301
302
303
    Optional<Stmt> cache = this->current_stmt_;
    this->current_stmt_ = stmt;
    Stmt output = Parent::VisitStmt(stmt);
    this->current_stmt_ = std::move(cache);
    return output;
  }

304
  Stmt VisitStmt_(const ForNode *op) final {
305
306
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
    With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
307
308
    With<ConstraintContext> ctx2(analyzer_,
                                 op->loop_var < op->min + op->extent);
309
310
311
    return Parent::VisitStmt_(op);
  }

312
313
314
315
316
  bool CanInlineLetStmt(const LetStmtNode *op) {
    if (is_const_number(op->value))
      return true;
    if (op->value.as<VarNode>())
      return true;
317
    // Won't face the deep expression explosion problem as in Let expression.
318
319
320
321
    // attempt to inline as much as possible if the value integer type(can be
    // index).
    if (!op->value.dtype().is_int())
      return false;
322
323
324
    return SideEffect(op->value) <= CallEffectKind::kPure;
  }

325
  Stmt VisitStmt_(const LetStmtNode *op) override {
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
    PrimExpr value = this->VisitExpr(op->value);
    bool can_inline = CanInlineLetStmt(op);
    if (can_inline) {
      // It is usually fine to discard the let binding because the
      // call to simplify will always inline the var.
      //
      // The exception is when the variable is used in a Buffer's
      // definition, as these are not updated by the simplification.
      // After DeclBuffer is required prior to use of a buffer,
      // simplifying can update the buffer definition as well.  The
      // buffer can only be updated at its point of definition,
      // because the points of use may occur within contexts that
      // allow for additional simplifications (e.g. a buffer of shape
      // [i,j] whose first use occurs within "if i==1" should not have
      // its shape simplified to [1,j]).
      analyzer_->Bind(op->var, value);
    } else if (SideEffect(op->value) <= CallEffectKind::kPure) {
      // Even if we aren't replacing all occurrences, they may be
      // necessary for proving conditional statements.
      non_inlined_bindings_.Set(op->var, value);
    }
    Stmt body = this->VisitStmt(op->body);

    // TODO(Lunderberg): Update the Buffer object as part of
    // DeclBuffer updates, which will first require
    // https://github.com/apache/tvm/pull/14778.
    bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

    if (can_inline && !used_in_buffer_def) {
      return body;
    } else if (value.same_as(op->value) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = this->CopyOnWrite(op);
      n->value = std::move(value);
      n->body = std::move(body);
      return Stmt(n);
    }
  }

366
  Stmt VisitStmt_(const IfThenElseNode *op) override {
367
368
369
370
371
372
373
374
375
376
377
378
379
    if (Optional<Bool> cond = ProveCondition(op->condition)) {
      if (cond.value()->value) {
        return this->VisitStmt(op->then_case);
      } else if (op->else_case) {
        return this->VisitStmt(op->else_case.value());
      } else {
        return Evaluate(0);
      }
    } else {
      return Parent::VisitStmt_(op);
    }
  }

380
  PrimExpr VisitExpr_(const CallNode *op) override {
381
382
383
384
385
386
387
388
389
390
391
392
    if (op->op.same_as(builtin::if_then_else())) {
      if (Optional<Bool> cond = ProveCondition(op->args[0])) {
        if (cond.value()->value) {
          return this->VisitExpr(op->args[1]);
        } else {
          return this->VisitExpr(op->args[2]);
        }
      }
    }
    return Parent::VisitExpr_(op);
  }

393
  PrimExpr VisitExpr_(const VarNode *op) override {
394
395
396
397
    used_vars_.insert(op);
    return Parent::VisitExpr_(op);
  }

398
  PrimExpr VisitExpr_(const BufferLoadNode *op) override {
399
400
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
401
      used_buffers_.insert(buffer);
402
403
404
405
406
    }
    return Parent::VisitExpr_(op);
  }

  // eliminate useless stores
407
  Stmt VisitStmt_(const BufferStoreNode *op) override {
408
    BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
409
    if (const BufferLoadNode *load = store->value.as<BufferLoadNode>()) {
410
411
      if (load->buffer->data.same_as(store->buffer->data) &&
          ArrayDeepEqual(load->indices, store->indices) &&
412
413
          tir::ExprDeepEqual()(load->buffer->elem_offset,
                               store->buffer->elem_offset) &&
414
415
416
417
418
419
420
          ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
          ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
        return Evaluate(0);
      }
    }
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
421
      used_buffers_.insert(buffer);
422
423
424
425
    }
    return std::move(store);
  }

426
427
private:
  bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
    if (lhs.size() != rhs.size()) {
      return false;
    }
    for (size_t i = 0; i < lhs.size(); i++) {
      if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) {
        return false;
      }
    }
    return true;
  }

  /* \brief Internal utility for checking conditionals
   *
   * Uses more aggressive optimization, such as performing additional
   * inlining and tracking known buffer values.
   */
  Optional<Bool> ProveCondition(PrimExpr condition) const {
    condition = Substitute(condition, non_inlined_bindings_);
    if (config_->propagate_knowns_to_prove_conditional) {
      ICHECK(touch_pattern_.has_value());
448
449
      condition = touch_pattern_->SimplifyInContext(
          condition, current_stmt_.value(), analyzer_);
450
451
452
    } else {
      condition = analyzer_->Simplify(condition);
    }
453
    if (const int64_t *as_int = as_const_int(condition)) {
454
455
      return Bool(*as_int);
    } else {
456
457
458
459
460
461
      // May have symbolic, need kSymbolicBound level prover.
      if (analyzer_->CanProve(condition) ||
          analyzer_->CanProve(condition,
                              arith::ProofStrength::kSymbolicBound)) {
        return Bool(true);
      }
462
      return std::nullopt;
463
464
465
466
467
468
469
    }
  }

  SimplifyConfig config_;
  std::optional<ControlFlowGraph> touch_pattern_;

  Map<Var, PrimExpr> non_inlined_bindings_;
470
  Optional<Stmt> current_stmt_{std::nullopt};
471
472
473
  std::unordered_set<const VarNode *> used_in_buffer_def_;
  std::unordered_set<const VarNode *> used_vars_;
  std::unordered_set<const BufferNode *> used_buffers_;
474
475
476
477
};

using namespace tir::transform;

478
tvm::transform::Pass Simplify(bool simplify_arguments = true) {
479
480
481
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    arith::Analyzer analyzer;
    auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
482
    return StmtSimplifier::Apply(f, &analyzer, cfg, simplify_arguments);
483
484
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
485
486
}

487
488
489
490
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.Simplify", Simplify);
});
491
492
493

} // namespace tl
} // namespace tvm