/*!
 * \file simplify.cc
 * \brief Statement simplifier based on the arithmetic analyzer; can also
 * remove unused buffer parameters of a TL PrimFunc.
 */

7
#include <tvm/ffi/reflection/registry.h>
8
#include <tvm/tir/analysis.h>
9
#include <tvm/tir/buffer.h>
10
11
12
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
13
#include <tvm/tir/utils.h>
14

15
#include <optional>
16
17
#include <utility>

18
19
#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/control_flow_graph.h"
20
#include "tir/analysis/var_use_def_analysis.h"
21
22
23
24
25
26
27

namespace tvm {
namespace tl {

using namespace tir;
using namespace arith;

28
struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
29
30
31
32
33
  bool transitively_prove_inequalities{};
  bool propagate_knowns_to_prove_conditional{};
  bool propagate_knowns_to_simplify_expressions{};
  bool convert_boolean_to_and_of_ors{};
  bool apply_constraints_to_boolean_branches{};
34

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<SimplifyConfigNode>()
        .def_ro("transitively_prove_inequalities",
                &SimplifyConfigNode::transitively_prove_inequalities,
                "If true, simplify conditionals with transitive combinations "
                "of scoped constraints",
                refl::DefaultValue(false))
        .def_ro("propagate_knowns_to_prove_conditional",
                &SimplifyConfigNode::propagate_knowns_to_prove_conditional,
                "If true, known buffer values are propagated and used to "
                "statically prove conditionals",
                refl::DefaultValue(false))
        .def_ro("propagate_knowns_to_simplify_expressions",
                &SimplifyConfigNode::propagate_knowns_to_simplify_expressions,
                "If true, known buffer values are propagated and used to "
                "replace BufferLoad wherever "
                "possible",
                refl::DefaultValue(false))
        .def_ro("convert_boolean_to_and_of_ors",
                &SimplifyConfigNode::convert_boolean_to_and_of_ors,
                "If true, simplify conditionals into an AND of ORs",
                refl::DefaultValue(false))
        .def_ro("apply_constraints_to_boolean_branches",
                &SimplifyConfigNode::apply_constraints_to_boolean_branches,
                "If true, simplify each branch of AND/OR under a constraints "
                "provided by the other "
                "branch",
                refl::DefaultValue(false));
64
  }
65
66
  static constexpr const char *_type_key = "tl.transform.SimplifyConfig";
  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode);
67

68
  RewriteSimplifier::Extension GetEnabledExtensions() const {
69
70
    RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
    if (transitively_prove_inequalities) {
71
72
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kTransitivelyProveInequalities);
73
74
    }
    if (convert_boolean_to_and_of_ors) {
75
76
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kConvertBooleanToAndOfOrs);
77
78
    }
    if (apply_constraints_to_boolean_branches) {
79
80
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kApplyConstraintsToBooleanBranches);
81
82
83
84
85
    }
    return flags;
  }
};

86
87
std::unordered_set<const BufferNode *>
CollectUsedBuffers(const PrimFunc &func) {
88
89
90
91
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

92
    Visitor(PrimFunc func) : func(std::move(func)) {}
93

94
95
96
97
98
99
    void VisitExpr_(const CallNode *op) override {
      for (const auto &arg : op->args) {
        for (const auto &it : func->buffer_map) {
          if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
            used_in_buffer_def_.insert(it.second.get());
          }
100
        }
101
102
      }
      StmtExprVisitor::VisitExpr_(op);
103
    }
104
    void VisitExpr_(const BufferLoadNode *op) override {
105
106
107
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
108
    void VisitStmt_(const BufferStoreNode *op) override {
109
110
111
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }
112
113
114
115
116
117
    void VisitStmt_(const BlockNode *op) override {
      for (const auto &buffer : op->alloc_buffers) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
118
        }
119
120
121
122
123
124
      }
      for (const auto &buffer : op->reads) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
125
        }
126
127
128
129
130
131
      }
      for (const auto &buffer : op->writes) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
132
        }
133
134
      }
      StmtExprVisitor::VisitStmt_(op);
135
136
    }

137
    void VisitBuffer(const Buffer &buf) {
138
139
140
      // Collect buffers that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
141
      for (const auto &dim : buf->shape) {
142
143
        usage(dim);
      }
144
      for (const auto &dim : buf->strides) {
145
146
147
148
        usage(dim);
      }
      usage(buf->elem_offset);

149
      for (const auto &buffer : usage.buffer_use_count_) {
150
        if (buffer.second >= 1) {
151
          used_in_buffer_def_.insert(buffer.first);
152
153
        }
      }
154
      for (const auto &buffer : usage.undefined_buffers_) {
155
156
157
158
        used_in_buffer_def_.insert(buffer.get());
      }
    }
    PrimFunc func;
159
    std::unordered_set<const BufferNode *> used_in_buffer_def_;
160
161
162
163
164
165
166
  };

  Visitor visitor(func);
  visitor(func->body);
  return visitor.used_in_buffer_def_;
}

167
168
169
170
171
/* \brief Utility function to collect vars that should be retained. Used in
 * Letstmt Only
 */
std::unordered_set<const VarNode *>
CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
172
173
174
175
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

176
    void VisitExpr_(const BufferLoadNode *op) override {
177
178
179
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
180
    void VisitStmt_(const BufferStoreNode *op) override {
181
182
183
184
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }

185
    void VisitBuffer(const Buffer &buf) {
186
187
188
      // Collect variables that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
189
      for (const auto &dim : buf->shape) {
190
191
        usage(dim);
      }
192
      for (const auto &dim : buf->strides) {
193
194
195
196
197
        usage(dim);
      }
      usage(buf->elem_offset);

      // Track for use in LetStmtNode mutator
198
      for (const auto &var : usage.undefined_) {
199
200
201
        used_in_buffer_def_.insert(var.get());
      }
    }
202
    std::unordered_set<const VarNode *> used_in_buffer_def_;
203
204
205
206
207
208
209
210
  };

  Visitor visitor;
  visitor(stmt);
  return visitor.used_in_buffer_def_;
}

class SimplifyConfig : public Attrs {
211
212
213
public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
                                            SimplifyConfigNode);
214
};
215
TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); });
216
217
218
219
220

TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);

class StmtSimplifier : public IRMutatorWithAnalyzer {
221
public:
222
223
224
225
  static PrimFunc
  Apply(PrimFunc func, Analyzer *analyzer,
        const Optional<SimplifyConfig> &config_opt = std::nullopt,
        bool simplify_arguments = false) {
226
    auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
227
228
    analyzer->rewrite_simplify.SetEnabledExtensions(
        config->GetEnabledExtensions());
229
230
231
232
233
234
235

    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
    if (config->propagate_knowns_to_prove_conditional ||
        config->propagate_knowns_to_simplify_expressions) {
      touch_pattern = ControlFlowGraph(func->body);
    }

236
    std::unordered_set<const VarNode *> used_in_buffer_def =
237
238
239
240
241
242
243
244
245
        CollectVarsUsedInBufferDefinition(func->body);
    StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
                              std::move(used_in_buffer_def));
    simplifier.MarkBufferMapShapes(func);
    func.CopyOnWrite()->body = simplifier(func->body);

    // Begin to remove useless var and buffer
    // First get used buffers
    simplifier.used_buffers_ = CollectUsedBuffers(func);
246

247
248
249
250
    bool param_updated = false;
    Array<Var> new_params;
    Map<Var, Buffer> new_buffer_map;
    // Check whether each buffer is used
251
252
253
254
255
256
    for (const auto &var : func->params) {
      if (func->buffer_map.find(var) != func->buffer_map.end()) {
        if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
            simplifier.used_buffers_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
257
258
259
260
261
        } else if (simplifier.used_in_buffer_def_.find(
                       func->buffer_map[var]->data.get()) !=
                   simplifier.used_in_buffer_def_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
262
263
        } else {
          param_updated = true;
264
        }
265
      }
266
    }
267

268
    if (param_updated) {
269
270
      return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
                      new_buffer_map, func->attrs, func->span);
271
    } else {
272
      return func;
273
274
275
    }
  }

276
277
278
279
280
private:
  explicit StmtSimplifier(
      Analyzer *analyzer, SimplifyConfig config,
      std::optional<ControlFlowGraph> touch_pattern,
      std::unordered_set<const VarNode *> used_in_buffer_def)
281
282
283
      : IRMutatorWithAnalyzer(analyzer), config_(std::move(config)),
        touch_pattern_(std::move(touch_pattern)),
        used_in_buffer_def_(std::move(used_in_buffer_def)) {}
284
285
286
287
288
289

  using Parent = IRMutatorWithAnalyzer;
  using Parent::VisitExpr_;
  using Parent::VisitStmt;
  using Parent::VisitStmt_;

290
  PrimExpr VisitExpr(const PrimExpr &expr) final {
291
    if (config_->propagate_knowns_to_simplify_expressions) {
292
293
      return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(),
                                               analyzer_);
294
295
296
297
298
299
300
    } else {
      return analyzer_->Simplify(expr);
    }
  }

  Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }

301
  Stmt VisitStmt(const Stmt &stmt) override {
302
303
304
305
306
307
308
    Optional<Stmt> cache = this->current_stmt_;
    this->current_stmt_ = stmt;
    Stmt output = Parent::VisitStmt(stmt);
    this->current_stmt_ = std::move(cache);
    return output;
  }

309
  Stmt VisitStmt_(const ForNode *op) final {
310
311
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
    With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
312
313
    With<ConstraintContext> ctx2(analyzer_,
                                 op->loop_var < op->min + op->extent);
314
315
316
    return Parent::VisitStmt_(op);
  }

317
318
319
320
321
  bool CanInlineLetStmt(const LetStmtNode *op) {
    if (is_const_number(op->value))
      return true;
    if (op->value.as<VarNode>())
      return true;
322
    // Won't face the deep expression explosion problem as in Let expression.
323
324
325
326
    // attempt to inline as much as possible if the value integer type(can be
    // index).
    if (!op->value.dtype().is_int())
      return false;
327
328
329
    return SideEffect(op->value) <= CallEffectKind::kPure;
  }

330
  Stmt VisitStmt_(const LetStmtNode *op) override {
331
    PrimExpr value = this->VisitExpr(op->value);
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
    bool remove_buffer_alias = false;
    // TileLang emits aliases like `X_shared = buffer[0:128, 0:32]` to annotate
    // fragment types. TVM currently reinterprets vectorized/shared accesses as
    // Let-bound BufferLoad/BufferRegion nodes. If these bindings survive, later
    // passes (Layout rewrite, FlattenBuffer) substitute them with vector lanes
    // that our layout can't handle. Force-inline (by dropping the let) whenever
    // the alias spans more than 2 dims or carries vector lanes.
    auto get_ranges = [&](const PrimExpr &expr) -> Array<Range> {
      Array<Range> ranges;
      if (const auto *load = expr.as<BufferLoadNode>()) {
        for (const PrimExpr &index : load->indices) {
          if (const auto *ramp = index.as<RampNode>()) {
            ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
          } else {
            ranges.push_back(Range::FromMinExtent(index, Integer(1)));
          }
        }
      } else if (const auto *region = expr.as<BufferRegionNode>()) {
        for (const Range &range : region->region) {
          ranges.push_back(range);
        }
      }
      return ranges;
    };
    Array<Range> ranges = get_ranges(value);
    if (!ranges.empty()) {
      int non_unit_dims = 0;
      for (const Range &range : ranges) {
        PrimExpr extent = analyzer_->Simplify(range->extent);
        if (is_const_int(extent, 1) || analyzer_->CanProveEqual(extent, 1)) {
          continue;
        }
        ++non_unit_dims;
        if (non_unit_dims > 1) {
          remove_buffer_alias = true;
          break;
        }
      }
    }
    if (remove_buffer_alias) {
      Stmt body = this->VisitStmt(op->body);
      bool used = UsesVar(
          body, [&](const VarNode *var) { return var == op->var.get(); });
      ICHECK(!used) << "Let binding of BufferLoad is expected to be unused "
                       "before removal "
                    << op->var << " : " << op->value << " .";
      return body;
    }

381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
    bool can_inline = CanInlineLetStmt(op);
    if (can_inline) {
      analyzer_->Bind(op->var, value);
    } else if (SideEffect(op->value) <= CallEffectKind::kPure) {
      non_inlined_bindings_.Set(op->var, value);
    }
    Stmt body = this->VisitStmt(op->body);

    bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

    if (can_inline && !used_in_buffer_def) {
      return body;
    } else if (value.same_as(op->value) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = this->CopyOnWrite(op);
      n->value = std::move(value);
      n->body = std::move(body);
      return Stmt(n);
    }
  }

403
  Stmt VisitStmt_(const IfThenElseNode *op) override {
404
405
406
407
408
409
410
411
412
413
414
415
416
    if (Optional<Bool> cond = ProveCondition(op->condition)) {
      if (cond.value()->value) {
        return this->VisitStmt(op->then_case);
      } else if (op->else_case) {
        return this->VisitStmt(op->else_case.value());
      } else {
        return Evaluate(0);
      }
    } else {
      return Parent::VisitStmt_(op);
    }
  }

417
  PrimExpr VisitExpr_(const CallNode *op) override {
418
419
420
421
422
423
424
425
426
427
428
429
    if (op->op.same_as(builtin::if_then_else())) {
      if (Optional<Bool> cond = ProveCondition(op->args[0])) {
        if (cond.value()->value) {
          return this->VisitExpr(op->args[1]);
        } else {
          return this->VisitExpr(op->args[2]);
        }
      }
    }
    return Parent::VisitExpr_(op);
  }

430
  PrimExpr VisitExpr_(const VarNode *op) override {
431
432
433
434
    used_vars_.insert(op);
    return Parent::VisitExpr_(op);
  }

435
  PrimExpr VisitExpr_(const BufferLoadNode *op) override {
436
437
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
438
      used_buffers_.insert(buffer);
439
440
441
442
443
    }
    return Parent::VisitExpr_(op);
  }

  // eliminate useless stores
444
  Stmt VisitStmt_(const BufferStoreNode *op) override {
445
    BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
446
    if (const BufferLoadNode *load = store->value.as<BufferLoadNode>()) {
447
448
      if (load->buffer->data.same_as(store->buffer->data) &&
          ArrayDeepEqual(load->indices, store->indices) &&
449
450
          tir::ExprDeepEqual()(load->buffer->elem_offset,
                               store->buffer->elem_offset) &&
451
452
453
454
455
456
457
          ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
          ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
        return Evaluate(0);
      }
    }
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
458
      used_buffers_.insert(buffer);
459
460
461
462
    }
    return std::move(store);
  }

463
464
private:
  bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
    if (lhs.size() != rhs.size()) {
      return false;
    }
    for (size_t i = 0; i < lhs.size(); i++) {
      if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) {
        return false;
      }
    }
    return true;
  }

  /* \brief Internal utility for checking conditionals
   *
   * Uses more aggressive optimization, such as performing additional
   * inlining and tracking known buffer values.
   */
  Optional<Bool> ProveCondition(PrimExpr condition) const {
    condition = Substitute(condition, non_inlined_bindings_);
    if (config_->propagate_knowns_to_prove_conditional) {
      ICHECK(touch_pattern_.has_value());
485
486
      condition = touch_pattern_->SimplifyInContext(
          condition, current_stmt_.value(), analyzer_);
487
488
489
    } else {
      condition = analyzer_->Simplify(condition);
    }
490
    if (const int64_t *as_int = as_const_int(condition)) {
491
492
      return Bool(*as_int);
    } else {
493
494
495
496
497
498
      // May have symbolic, need kSymbolicBound level prover.
      if (analyzer_->CanProve(condition) ||
          analyzer_->CanProve(condition,
                              arith::ProofStrength::kSymbolicBound)) {
        return Bool(true);
      }
499
      return std::nullopt;
500
501
502
503
504
505
506
    }
  }

  SimplifyConfig config_;
  std::optional<ControlFlowGraph> touch_pattern_;

  Map<Var, PrimExpr> non_inlined_bindings_;
507
  Optional<Stmt> current_stmt_{std::nullopt};
508
509
510
  std::unordered_set<const VarNode *> used_in_buffer_def_;
  std::unordered_set<const VarNode *> used_vars_;
  std::unordered_set<const BufferNode *> used_buffers_;
511
512
513
514
};

using namespace tir::transform;

515
tvm::transform::Pass Simplify(bool simplify_arguments = true) {
516
  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
517
518
    arith::Analyzer analyzer;
    auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
519
520
    return StmtSimplifier::Apply(std::move(f), &analyzer, cfg,
                                 simplify_arguments);
521
522
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
523
524
}

525
526
527
528
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.Simplify", Simplify);
});
529
530
531

} // namespace tl
} // namespace tvm