/*!
 * \file simplify.cc
 * \brief Statement simplifier built on the arithmetic analyzer; it can also
 * remove unused buffer parameters of a TL PrimFunc.
 */

#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/control_flow_graph.h"
#include "tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace tl {

using namespace tir;
using namespace arith;

26
struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
27
28
29
30
31
  bool transitively_prove_inequalities{};
  bool propagate_knowns_to_prove_conditional{};
  bool propagate_knowns_to_simplify_expressions{};
  bool convert_boolean_to_and_of_ors{};
  bool apply_constraints_to_boolean_branches{};
32

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<SimplifyConfigNode>()
        .def_ro("transitively_prove_inequalities",
                &SimplifyConfigNode::transitively_prove_inequalities,
                "If true, simplify conditionals with transitive combinations "
                "of scoped constraints",
                refl::DefaultValue(false))
        .def_ro("propagate_knowns_to_prove_conditional",
                &SimplifyConfigNode::propagate_knowns_to_prove_conditional,
                "If true, known buffer values are propagated and used to "
                "statically prove conditionals",
                refl::DefaultValue(false))
        .def_ro("propagate_knowns_to_simplify_expressions",
                &SimplifyConfigNode::propagate_knowns_to_simplify_expressions,
                "If true, known buffer values are propagated and used to "
                "replace BufferLoad wherever "
                "possible",
                refl::DefaultValue(false))
        .def_ro("convert_boolean_to_and_of_ors",
                &SimplifyConfigNode::convert_boolean_to_and_of_ors,
                "If true, simplify conditionals into an AND of ORs",
                refl::DefaultValue(false))
        .def_ro("apply_constraints_to_boolean_branches",
                &SimplifyConfigNode::apply_constraints_to_boolean_branches,
                "If true, simplify each branch of AND/OR under a constraints "
                "provided by the other "
                "branch",
                refl::DefaultValue(false));
62
  }
63
64
  static constexpr const char *_type_key = "tl.transform.SimplifyConfig";
  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode);
65

66
  RewriteSimplifier::Extension GetEnabledExtensions() const {
67
68
    RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
    if (transitively_prove_inequalities) {
69
70
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kTransitivelyProveInequalities);
71
72
    }
    if (convert_boolean_to_and_of_ors) {
73
74
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kConvertBooleanToAndOfOrs);
75
76
    }
    if (apply_constraints_to_boolean_branches) {
77
78
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kApplyConstraintsToBooleanBranches);
79
80
81
82
83
    }
    return flags;
  }
};

84
85
std::unordered_set<const BufferNode *>
CollectUsedBuffers(const PrimFunc &func) {
86
87
88
89
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

90
    Visitor(PrimFunc func) : func(std::move(func)) {}
91

92
93
94
95
96
97
    void VisitExpr_(const CallNode *op) override {
      for (const auto &arg : op->args) {
        for (const auto &it : func->buffer_map) {
          if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
            used_in_buffer_def_.insert(it.second.get());
          }
98
        }
99
100
      }
      StmtExprVisitor::VisitExpr_(op);
101
    }
102
    void VisitExpr_(const BufferLoadNode *op) override {
103
104
105
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
106
    void VisitStmt_(const BufferStoreNode *op) override {
107
108
109
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }
110
111
112
113
114
115
    void VisitStmt_(const BlockNode *op) override {
      for (const auto &buffer : op->alloc_buffers) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
116
        }
117
118
119
120
121
122
      }
      for (const auto &buffer : op->reads) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
123
        }
124
125
126
127
128
129
      }
      for (const auto &buffer : op->writes) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
130
        }
131
132
      }
      StmtExprVisitor::VisitStmt_(op);
133
134
    }

135
    void VisitBuffer(const Buffer &buf) {
136
137
138
      // Collect buffers that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
139
      for (const auto &dim : buf->shape) {
140
141
        usage(dim);
      }
142
      for (const auto &dim : buf->strides) {
143
144
145
146
        usage(dim);
      }
      usage(buf->elem_offset);

147
      for (const auto &buffer : usage.buffer_use_count_) {
148
        if (buffer.second >= 1) {
149
          used_in_buffer_def_.insert(buffer.first);
150
151
        }
      }
152
      for (const auto &buffer : usage.undefined_buffers_) {
153
154
155
156
        used_in_buffer_def_.insert(buffer.get());
      }
    }
    PrimFunc func;
157
    std::unordered_set<const BufferNode *> used_in_buffer_def_;
158
159
160
161
162
163
164
  };

  Visitor visitor(func);
  visitor(func->body);
  return visitor.used_in_buffer_def_;
}

165
166
167
168
169
/* \brief Utility function to collect vars that should be retained. Used in
 * Letstmt Only
 */
std::unordered_set<const VarNode *>
CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
170
171
172
173
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

174
    void VisitExpr_(const BufferLoadNode *op) override {
175
176
177
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
178
    void VisitStmt_(const BufferStoreNode *op) override {
179
180
181
182
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }

183
    void VisitBuffer(const Buffer &buf) {
184
185
186
      // Collect variables that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
187
      for (const auto &dim : buf->shape) {
188
189
        usage(dim);
      }
190
      for (const auto &dim : buf->strides) {
191
192
193
194
195
        usage(dim);
      }
      usage(buf->elem_offset);

      // Track for use in LetStmtNode mutator
196
      for (const auto &var : usage.undefined_) {
197
198
199
        used_in_buffer_def_.insert(var.get());
      }
    }
200
    std::unordered_set<const VarNode *> used_in_buffer_def_;
201
202
203
204
205
206
207
208
  };

  Visitor visitor;
  visitor(stmt);
  return visitor.used_in_buffer_def_;
}

class SimplifyConfig : public Attrs {
209
210
211
public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
                                            SimplifyConfigNode);
212
};
213
TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); });
214
215
216
217
218

TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);

class StmtSimplifier : public IRMutatorWithAnalyzer {
219
public:
220
221
222
223
  static PrimFunc
  Apply(PrimFunc func, Analyzer *analyzer,
        const Optional<SimplifyConfig> &config_opt = std::nullopt,
        bool simplify_arguments = false) {
224
    auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
225
226
    analyzer->rewrite_simplify.SetEnabledExtensions(
        config->GetEnabledExtensions());
227
228
229
230
231
232
233

    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
    if (config->propagate_knowns_to_prove_conditional ||
        config->propagate_knowns_to_simplify_expressions) {
      touch_pattern = ControlFlowGraph(func->body);
    }

234
    std::unordered_set<const VarNode *> used_in_buffer_def =
235
236
237
238
239
240
241
242
243
        CollectVarsUsedInBufferDefinition(func->body);
    StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
                              std::move(used_in_buffer_def));
    simplifier.MarkBufferMapShapes(func);
    func.CopyOnWrite()->body = simplifier(func->body);

    // Begin to remove useless var and buffer
    // First get used buffers
    simplifier.used_buffers_ = CollectUsedBuffers(func);
244

245
246
247
248
    bool param_updated = false;
    Array<Var> new_params;
    Map<Var, Buffer> new_buffer_map;
    // Check whether each buffer is used
249
250
251
252
253
254
    for (const auto &var : func->params) {
      if (func->buffer_map.find(var) != func->buffer_map.end()) {
        if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
            simplifier.used_buffers_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
255
256
257
258
259
        } else if (simplifier.used_in_buffer_def_.find(
                       func->buffer_map[var]->data.get()) !=
                   simplifier.used_in_buffer_def_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
260
261
        } else {
          param_updated = true;
262
        }
263
      }
264
    }
265

266
    if (param_updated) {
267
268
      return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
                      new_buffer_map, func->attrs, func->span);
269
    } else {
270
      return func;
271
272
273
    }
  }

274
275
276
277
278
private:
  explicit StmtSimplifier(
      Analyzer *analyzer, SimplifyConfig config,
      std::optional<ControlFlowGraph> touch_pattern,
      std::unordered_set<const VarNode *> used_in_buffer_def)
279
280
281
      : IRMutatorWithAnalyzer(analyzer), config_(std::move(config)),
        touch_pattern_(std::move(touch_pattern)),
        used_in_buffer_def_(std::move(used_in_buffer_def)) {}
282
283
284
285
286
287

  using Parent = IRMutatorWithAnalyzer;
  using Parent::VisitExpr_;
  using Parent::VisitStmt;
  using Parent::VisitStmt_;

288
  PrimExpr VisitExpr(const PrimExpr &expr) final {
289
    if (config_->propagate_knowns_to_simplify_expressions) {
290
291
      return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(),
                                               analyzer_);
292
293
294
295
296
297
298
    } else {
      return analyzer_->Simplify(expr);
    }
  }

  Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }

299
  Stmt VisitStmt(const Stmt &stmt) override {
300
301
302
303
304
305
306
    Optional<Stmt> cache = this->current_stmt_;
    this->current_stmt_ = stmt;
    Stmt output = Parent::VisitStmt(stmt);
    this->current_stmt_ = std::move(cache);
    return output;
  }

307
  Stmt VisitStmt_(const ForNode *op) final {
308
309
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
    With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
310
311
    With<ConstraintContext> ctx2(analyzer_,
                                 op->loop_var < op->min + op->extent);
312
313
314
    return Parent::VisitStmt_(op);
  }

315
316
317
318
319
  bool CanInlineLetStmt(const LetStmtNode *op) {
    if (is_const_number(op->value))
      return true;
    if (op->value.as<VarNode>())
      return true;
320
    // Won't face the deep expression explosion problem as in Let expression.
321
322
323
324
    // attempt to inline as much as possible if the value integer type(can be
    // index).
    if (!op->value.dtype().is_int())
      return false;
325
326
327
    return SideEffect(op->value) <= CallEffectKind::kPure;
  }

328
  Stmt VisitStmt_(const LetStmtNode *op) override {
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
    PrimExpr value = this->VisitExpr(op->value);
    bool can_inline = CanInlineLetStmt(op);
    if (can_inline) {
      // It is usually fine to discard the let binding because the
      // call to simplify will always inline the var.
      //
      // The exception is when the variable is used in a Buffer's
      // definition, as these are not updated by the simplification.
      // After DeclBuffer is required prior to use of a buffer,
      // simplifying can update the buffer definition as well.  The
      // buffer can only be updated at its point of definition,
      // because the points of use may occur within contexts that
      // allow for additional simplifications (e.g. a buffer of shape
      // [i,j] whose first use occurs within "if i==1" should not have
      // its shape simplified to [1,j]).
      analyzer_->Bind(op->var, value);
    } else if (SideEffect(op->value) <= CallEffectKind::kPure) {
      // Even if we aren't replacing all occurrences, they may be
      // necessary for proving conditional statements.
      non_inlined_bindings_.Set(op->var, value);
    }
    Stmt body = this->VisitStmt(op->body);

    // TODO(Lunderberg): Update the Buffer object as part of
    // DeclBuffer updates, which will first require
    // https://github.com/apache/tvm/pull/14778.
    bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

    if (can_inline && !used_in_buffer_def) {
      return body;
    } else if (value.same_as(op->value) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = this->CopyOnWrite(op);
      n->value = std::move(value);
      n->body = std::move(body);
      return Stmt(n);
    }
  }

369
  Stmt VisitStmt_(const IfThenElseNode *op) override {
370
371
372
373
374
375
376
377
378
379
380
381
382
    if (Optional<Bool> cond = ProveCondition(op->condition)) {
      if (cond.value()->value) {
        return this->VisitStmt(op->then_case);
      } else if (op->else_case) {
        return this->VisitStmt(op->else_case.value());
      } else {
        return Evaluate(0);
      }
    } else {
      return Parent::VisitStmt_(op);
    }
  }

383
  PrimExpr VisitExpr_(const CallNode *op) override {
384
385
386
387
388
389
390
391
392
393
394
395
    if (op->op.same_as(builtin::if_then_else())) {
      if (Optional<Bool> cond = ProveCondition(op->args[0])) {
        if (cond.value()->value) {
          return this->VisitExpr(op->args[1]);
        } else {
          return this->VisitExpr(op->args[2]);
        }
      }
    }
    return Parent::VisitExpr_(op);
  }

396
  PrimExpr VisitExpr_(const VarNode *op) override {
397
398
399
400
    used_vars_.insert(op);
    return Parent::VisitExpr_(op);
  }

401
  PrimExpr VisitExpr_(const BufferLoadNode *op) override {
402
403
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
404
      used_buffers_.insert(buffer);
405
406
407
408
409
    }
    return Parent::VisitExpr_(op);
  }

  // eliminate useless stores
410
  Stmt VisitStmt_(const BufferStoreNode *op) override {
411
    BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
412
    if (const BufferLoadNode *load = store->value.as<BufferLoadNode>()) {
413
414
      if (load->buffer->data.same_as(store->buffer->data) &&
          ArrayDeepEqual(load->indices, store->indices) &&
415
416
          tir::ExprDeepEqual()(load->buffer->elem_offset,
                               store->buffer->elem_offset) &&
417
418
419
420
421
422
423
          ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
          ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
        return Evaluate(0);
      }
    }
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
424
      used_buffers_.insert(buffer);
425
426
427
428
    }
    return std::move(store);
  }

429
430
private:
  bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
    if (lhs.size() != rhs.size()) {
      return false;
    }
    for (size_t i = 0; i < lhs.size(); i++) {
      if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) {
        return false;
      }
    }
    return true;
  }

  /* \brief Internal utility for checking conditionals
   *
   * Uses more aggressive optimization, such as performing additional
   * inlining and tracking known buffer values.
   */
  Optional<Bool> ProveCondition(PrimExpr condition) const {
    condition = Substitute(condition, non_inlined_bindings_);
    if (config_->propagate_knowns_to_prove_conditional) {
      ICHECK(touch_pattern_.has_value());
451
452
      condition = touch_pattern_->SimplifyInContext(
          condition, current_stmt_.value(), analyzer_);
453
454
455
    } else {
      condition = analyzer_->Simplify(condition);
    }
456
    if (const int64_t *as_int = as_const_int(condition)) {
457
458
      return Bool(*as_int);
    } else {
459
460
461
462
463
464
      // May have symbolic, need kSymbolicBound level prover.
      if (analyzer_->CanProve(condition) ||
          analyzer_->CanProve(condition,
                              arith::ProofStrength::kSymbolicBound)) {
        return Bool(true);
      }
465
      return std::nullopt;
466
467
468
469
470
471
472
    }
  }

  SimplifyConfig config_;
  std::optional<ControlFlowGraph> touch_pattern_;

  Map<Var, PrimExpr> non_inlined_bindings_;
473
  Optional<Stmt> current_stmt_{std::nullopt};
474
475
476
  std::unordered_set<const VarNode *> used_in_buffer_def_;
  std::unordered_set<const VarNode *> used_vars_;
  std::unordered_set<const BufferNode *> used_buffers_;
477
478
479
480
};

using namespace tir::transform;

481
tvm::transform::Pass Simplify(bool simplify_arguments = true) {
482
  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
483
484
    arith::Analyzer analyzer;
    auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
485
486
    return StmtSimplifier::Apply(std::move(f), &analyzer, cfg,
                                 simplify_arguments);
487
488
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
489
490
}

491
492
493
494
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.Simplify", Simplify);
});
495
496
497

} // namespace tl
} // namespace tvm