simplify.cc 18 KB
Newer Older
1
2
/*!
 * \file simplify.cc
 * \brief Statement simplifier built on the arithmetic analyzer; also removes
 * unused buffer parameters from TL PrimFuncs.
 */

7
#include <tvm/ffi/reflection/registry.h>
8
#include <tvm/tir/analysis.h>
9
#include <tvm/tir/buffer.h>
10
11
12
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
13
#include <tvm/tir/utils.h>
14

15
#include <optional>
16
17
#include <utility>

18
19
#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/control_flow_graph.h"
20
#include "tir/analysis/var_use_def_analysis.h"
21
22
23
24
25

namespace tvm {
namespace tl {

using namespace tir;
26
using namespace ffi;
27
28
using namespace arith;

29
/*!
 * \brief Options controlling the "tl.Simplify" pass.
 *
 * Three of the flags select RewriteSimplifier extensions (see
 * GetEnabledExtensions). The two propagate_knowns_* flags are not rewrite
 * extensions; they are read directly by the statement simplifier.
 * All flags default to false.
 */
struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
  bool transitively_prove_inequalities{};
  bool propagate_knowns_to_prove_conditional{};
  bool propagate_knowns_to_simplify_expressions{};
  bool convert_boolean_to_and_of_ors{};
  bool apply_constraints_to_boolean_branches{};

  /*! \brief Register every option with the FFI reflection system. */
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<SimplifyConfigNode>()
        .def_ro("transitively_prove_inequalities",
                &SimplifyConfigNode::transitively_prove_inequalities,
                "If true, simplify conditionals with transitive combinations "
                "of scoped constraints",
                refl::DefaultValue(false))
        .def_ro("propagate_knowns_to_prove_conditional",
                &SimplifyConfigNode::propagate_knowns_to_prove_conditional,
                "If true, known buffer values are propagated and used to "
                "statically prove conditionals",
                refl::DefaultValue(false))
        .def_ro("propagate_knowns_to_simplify_expressions",
                &SimplifyConfigNode::propagate_knowns_to_simplify_expressions,
                "If true, known buffer values are propagated and used to "
                "replace BufferLoad wherever "
                "possible",
                refl::DefaultValue(false))
        .def_ro("convert_boolean_to_and_of_ors",
                &SimplifyConfigNode::convert_boolean_to_and_of_ors,
                "If true, simplify conditionals into an AND of ORs",
                refl::DefaultValue(false))
        .def_ro("apply_constraints_to_boolean_branches",
                &SimplifyConfigNode::apply_constraints_to_boolean_branches,
                "If true, simplify each branch of AND/OR under a constraints "
                "provided by the other "
                "branch",
                refl::DefaultValue(false));
  }

  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.SimplifyConfig",
                                    SimplifyConfigNode, BaseAttrsNode);

  /*!
   * \brief Translate the boolean options into a RewriteSimplifier extension
   *        mask.
   * \return kNone OR-ed with every extension whose flag is set.
   */
  RewriteSimplifier::Extension GetEnabledExtensions() const {
    // Build the mask as a plain integer; convert back to the enum once.
    int mask = RewriteSimplifier::kNone;
    if (transitively_prove_inequalities) {
      mask |= RewriteSimplifier::kTransitivelyProveInequalities;
    }
    if (convert_boolean_to_and_of_ors) {
      mask |= RewriteSimplifier::kConvertBooleanToAndOfOrs;
    }
    if (apply_constraints_to_boolean_branches) {
      mask |= RewriteSimplifier::kApplyConstraintsToBooleanBranches;
    }
    return static_cast<RewriteSimplifier::Extension>(mask);
  }
};

87
88
std::unordered_set<const BufferNode *>
CollectUsedBuffers(const PrimFunc &func) {
89
90
91
92
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

93
    Visitor(PrimFunc func) : func(std::move(func)) {}
94

95
96
97
98
99
100
    void VisitExpr_(const CallNode *op) override {
      for (const auto &arg : op->args) {
        for (const auto &it : func->buffer_map) {
          if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
            used_in_buffer_def_.insert(it.second.get());
          }
101
        }
102
103
      }
      StmtExprVisitor::VisitExpr_(op);
104
    }
105
    void VisitExpr_(const BufferLoadNode *op) override {
106
107
108
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
109
    void VisitStmt_(const BufferStoreNode *op) override {
110
111
112
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }
113
114
115
116
117
118
    void VisitStmt_(const BlockNode *op) override {
      for (const auto &buffer : op->alloc_buffers) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
119
        }
120
121
122
123
124
125
      }
      for (const auto &buffer : op->reads) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
126
        }
127
128
129
130
131
132
      }
      for (const auto &buffer : op->writes) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
133
        }
134
135
      }
      StmtExprVisitor::VisitStmt_(op);
136
137
    }

138
    void VisitBuffer(const Buffer &buf) {
139
140
141
      // Collect buffers that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
142
      for (const auto &dim : buf->shape) {
143
144
        usage(dim);
      }
145
      for (const auto &dim : buf->strides) {
146
147
148
149
        usage(dim);
      }
      usage(buf->elem_offset);

150
      for (const auto &buffer : usage.buffer_use_count_) {
151
        if (buffer.second >= 1) {
152
          used_in_buffer_def_.insert(buffer.first);
153
154
        }
      }
155
      for (const auto &buffer : usage.undefined_buffers_) {
156
157
158
159
        used_in_buffer_def_.insert(buffer.get());
      }
    }
    PrimFunc func;
160
    std::unordered_set<const BufferNode *> used_in_buffer_def_;
161
162
163
164
165
166
167
  };

  Visitor visitor(func);
  visitor(func->body);
  return visitor.used_in_buffer_def_;
}

168
169
170
171
172
/* \brief Utility function to collect vars that should be retained. Used in
 * Letstmt Only
 */
std::unordered_set<const VarNode *>
CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
173
174
175
176
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

177
    void VisitExpr_(const BufferLoadNode *op) override {
178
179
180
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
181
    void VisitStmt_(const BufferStoreNode *op) override {
182
183
184
185
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }

186
    void VisitBuffer(const Buffer &buf) {
187
188
189
      // Collect variables that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
190
      for (const auto &dim : buf->shape) {
191
192
        usage(dim);
      }
193
      for (const auto &dim : buf->strides) {
194
195
196
197
198
        usage(dim);
      }
      usage(buf->elem_offset);

      // Track for use in LetStmtNode mutator
199
      for (const auto &var : usage.undefined_) {
200
201
202
        used_in_buffer_def_.insert(var.get());
      }
    }
203
    std::unordered_set<const VarNode *> used_in_buffer_def_;
204
205
206
207
208
209
210
211
  };

  Visitor visitor;
  visitor(stmt);
  return visitor.used_in_buffer_def_;
}

class SimplifyConfig : public Attrs {
212
public:
213
214
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs,
                                                SimplifyConfigNode);
215
};
216
TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); }
217
218
219
220

TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);

class StmtSimplifier : public IRMutatorWithAnalyzer {
221
public:
222
223
224
225
  static PrimFunc
  Apply(PrimFunc func, Analyzer *analyzer,
        const Optional<SimplifyConfig> &config_opt = std::nullopt,
        bool simplify_arguments = false) {
226
    auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
227
228
    analyzer->rewrite_simplify.SetEnabledExtensions(
        config->GetEnabledExtensions());
229
230
231
232
233
234
235

    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
    if (config->propagate_knowns_to_prove_conditional ||
        config->propagate_knowns_to_simplify_expressions) {
      touch_pattern = ControlFlowGraph(func->body);
    }

236
    std::unordered_set<const VarNode *> used_in_buffer_def =
237
238
239
240
241
242
243
244
245
        CollectVarsUsedInBufferDefinition(func->body);
    StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
                              std::move(used_in_buffer_def));
    simplifier.MarkBufferMapShapes(func);
    func.CopyOnWrite()->body = simplifier(func->body);

    // Begin to remove useless var and buffer
    // First get used buffers
    simplifier.used_buffers_ = CollectUsedBuffers(func);
246

247
248
249
250
    bool param_updated = false;
    Array<Var> new_params;
    Map<Var, Buffer> new_buffer_map;
    // Check whether each buffer is used
251
252
253
254
255
256
    for (const auto &var : func->params) {
      if (func->buffer_map.find(var) != func->buffer_map.end()) {
        if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
            simplifier.used_buffers_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
257
258
259
260
261
        } else if (simplifier.used_in_buffer_def_.find(
                       func->buffer_map[var]->data.get()) !=
                   simplifier.used_in_buffer_def_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
262
263
        } else {
          param_updated = true;
264
        }
265
      }
266
    }
267

268
    if (param_updated) {
269
270
      return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
                      new_buffer_map, func->attrs, func->span);
271
    } else {
272
      return func;
273
274
275
    }
  }

276
277
278
279
280
private:
  explicit StmtSimplifier(
      Analyzer *analyzer, SimplifyConfig config,
      std::optional<ControlFlowGraph> touch_pattern,
      std::unordered_set<const VarNode *> used_in_buffer_def)
281
282
283
      : IRMutatorWithAnalyzer(analyzer), config_(std::move(config)),
        touch_pattern_(std::move(touch_pattern)),
        used_in_buffer_def_(std::move(used_in_buffer_def)) {}
284
285
286
287
288
289

  using Parent = IRMutatorWithAnalyzer;
  using Parent::VisitExpr_;
  using Parent::VisitStmt;
  using Parent::VisitStmt_;

290
  PrimExpr VisitExpr(const PrimExpr &expr) final {
291
    if (config_->propagate_knowns_to_simplify_expressions) {
292
293
      return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(),
                                               analyzer_);
294
295
296
297
298
299
300
    } else {
      return analyzer_->Simplify(expr);
    }
  }

  Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }

301
  Stmt VisitStmt(const Stmt &stmt) override {
302
303
304
305
306
307
308
    Optional<Stmt> cache = this->current_stmt_;
    this->current_stmt_ = stmt;
    Stmt output = Parent::VisitStmt(stmt);
    this->current_stmt_ = std::move(cache);
    return output;
  }

309
  Stmt VisitStmt_(const ForNode *op) final {
310
311
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
    With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
312
313
    With<ConstraintContext> ctx2(analyzer_,
                                 op->loop_var < op->min + op->extent);
314
315
316
    return Parent::VisitStmt_(op);
  }

317
318
319
320
321
  bool CanInlineLetStmt(const LetStmtNode *op) {
    if (is_const_number(op->value))
      return true;
    if (op->value.as<VarNode>())
      return true;
322
    // Won't face the deep expression explosion problem as in Let expression.
323
324
325
326
    // attempt to inline as much as possible if the value integer type(can be
    // index).
    if (!op->value.dtype().is_int())
      return false;
327
328
329
    return SideEffect(op->value) <= CallEffectKind::kPure;
  }

330
  Stmt VisitStmt_(const LetStmtNode *op) override {
331
    PrimExpr value = this->VisitExpr(op->value);
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
    bool remove_buffer_alias = false;
    // TileLang emits aliases like `X_shared = buffer[0:128, 0:32]` to annotate
    // fragment types. TVM currently reinterprets vectorized/shared accesses as
    // Let-bound BufferLoad/BufferRegion nodes. If these bindings survive, later
    // passes (Layout rewrite, FlattenBuffer) substitute them with vector lanes
    // that our layout can't handle. Force-inline (by dropping the let) whenever
    // the alias spans more than 2 dims or carries vector lanes.
    auto get_ranges = [&](const PrimExpr &expr) -> Array<Range> {
      Array<Range> ranges;
      if (const auto *load = expr.as<BufferLoadNode>()) {
        for (const PrimExpr &index : load->indices) {
          if (const auto *ramp = index.as<RampNode>()) {
            ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
          } else {
            ranges.push_back(Range::FromMinExtent(index, Integer(1)));
          }
        }
      } else if (const auto *region = expr.as<BufferRegionNode>()) {
        for (const Range &range : region->region) {
          ranges.push_back(range);
        }
      }
      return ranges;
    };
    Array<Range> ranges = get_ranges(value);
    if (!ranges.empty()) {
      int non_unit_dims = 0;
      for (const Range &range : ranges) {
        PrimExpr extent = analyzer_->Simplify(range->extent);
        if (is_const_int(extent, 1) || analyzer_->CanProveEqual(extent, 1)) {
          continue;
        }
        ++non_unit_dims;
        if (non_unit_dims > 1) {
          remove_buffer_alias = true;
          break;
        }
      }
    }
    if (remove_buffer_alias) {
      Stmt body = this->VisitStmt(op->body);
      bool used = UsesVar(
          body, [&](const VarNode *var) { return var == op->var.get(); });
      ICHECK(!used) << "Let binding of BufferLoad is expected to be unused "
                       "before removal "
                    << op->var << " : " << op->value << " .";
      return body;
    }

381
382
383
384
385
386
387
388
389
390
391
392
393
    bool can_inline = CanInlineLetStmt(op);
    if (can_inline) {
      analyzer_->Bind(op->var, value);
    } else if (SideEffect(op->value) <= CallEffectKind::kPure) {
      non_inlined_bindings_.Set(op->var, value);
    }
    Stmt body = this->VisitStmt(op->body);

    bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

    if (can_inline && !used_in_buffer_def) {
      return body;
    } else if (value.same_as(op->value) && body.same_as(op->body)) {
394
      return tvm::ffi::GetRef<Stmt>(op);
395
396
397
398
399
400
401
402
    } else {
      auto n = this->CopyOnWrite(op);
      n->value = std::move(value);
      n->body = std::move(body);
      return Stmt(n);
    }
  }

403
  Stmt VisitStmt_(const IfThenElseNode *op) override {
404
405
406
407
408
409
410
411
412
413
414
415
416
    if (Optional<Bool> cond = ProveCondition(op->condition)) {
      if (cond.value()->value) {
        return this->VisitStmt(op->then_case);
      } else if (op->else_case) {
        return this->VisitStmt(op->else_case.value());
      } else {
        return Evaluate(0);
      }
    } else {
      return Parent::VisitStmt_(op);
    }
  }

417
  PrimExpr VisitExpr_(const CallNode *op) override {
418
419
420
421
422
423
424
425
426
427
428
429
    if (op->op.same_as(builtin::if_then_else())) {
      if (Optional<Bool> cond = ProveCondition(op->args[0])) {
        if (cond.value()->value) {
          return this->VisitExpr(op->args[1]);
        } else {
          return this->VisitExpr(op->args[2]);
        }
      }
    }
    return Parent::VisitExpr_(op);
  }

430
  PrimExpr VisitExpr_(const VarNode *op) override {
431
432
433
434
    used_vars_.insert(op);
    return Parent::VisitExpr_(op);
  }

435
  PrimExpr VisitExpr_(const BufferLoadNode *op) override {
436
437
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
438
      used_buffers_.insert(buffer);
439
440
441
442
443
    }
    return Parent::VisitExpr_(op);
  }

  // eliminate useless stores
444
  Stmt VisitStmt_(const BufferStoreNode *op) override {
445
    BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
446
    if (const BufferLoadNode *load = store->value.as<BufferLoadNode>()) {
447
448
      if (load->buffer->data.same_as(store->buffer->data) &&
          ArrayDeepEqual(load->indices, store->indices) &&
449
450
          tir::ExprDeepEqual()(load->buffer->elem_offset,
                               store->buffer->elem_offset) &&
451
452
453
454
455
456
457
          ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
          ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
        return Evaluate(0);
      }
    }
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
458
      used_buffers_.insert(buffer);
459
460
461
462
    }
    return std::move(store);
  }

463
464
private:
  bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
    if (lhs.size() != rhs.size()) {
      return false;
    }
    for (size_t i = 0; i < lhs.size(); i++) {
      if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) {
        return false;
      }
    }
    return true;
  }

  /* \brief Internal utility for checking conditionals
   *
   * Uses more aggressive optimization, such as performing additional
   * inlining and tracking known buffer values.
   */
  Optional<Bool> ProveCondition(PrimExpr condition) const {
    condition = Substitute(condition, non_inlined_bindings_);
    if (config_->propagate_knowns_to_prove_conditional) {
      ICHECK(touch_pattern_.has_value());
485
486
      condition = touch_pattern_->SimplifyInContext(
          condition, current_stmt_.value(), analyzer_);
487
488
489
    } else {
      condition = analyzer_->Simplify(condition);
    }
490
    if (const int64_t *as_int = as_const_int(condition)) {
491
492
      return Bool(*as_int);
    } else {
493
494
495
496
497
498
      // May have symbolic, need kSymbolicBound level prover.
      if (analyzer_->CanProve(condition) ||
          analyzer_->CanProve(condition,
                              arith::ProofStrength::kSymbolicBound)) {
        return Bool(true);
      }
499
      return std::nullopt;
500
501
502
503
504
505
506
    }
  }

  SimplifyConfig config_;
  std::optional<ControlFlowGraph> touch_pattern_;

  Map<Var, PrimExpr> non_inlined_bindings_;
507
  Optional<Stmt> current_stmt_{std::nullopt};
508
509
510
  std::unordered_set<const VarNode *> used_in_buffer_def_;
  std::unordered_set<const VarNode *> used_vars_;
  std::unordered_set<const BufferNode *> used_buffers_;
511
512
513
514
};

using namespace tir::transform;

515
/*!
 * \brief Create the "tl.Simplify" PrimFunc pass.
 *
 * Every function is processed by StmtSimplifier::Apply with a fresh
 * arith::Analyzer. It uses the SimplifyConfig stored under the "tl.Simplify"
 * pass-context key when one is set.
 *
 * \param simplify_arguments Forwarded unchanged to StmtSimplifier::Apply.
 */
tvm::transform::Pass Simplify(bool simplify_arguments = true) {
  auto simplify_one = [simplify_arguments](PrimFunc prim_func,
                                           const IRModule &module,
                                           PassContext pass_ctx) {
    arith::Analyzer analyzer;
    auto config = pass_ctx->GetConfig<SimplifyConfig>("tl.Simplify");
    return StmtSimplifier::Apply(std::move(prim_func), &analyzer, config,
                                 simplify_arguments);
  };
  return CreatePrimFuncPass(simplify_one, 0, "tl.Simplify", {});
}

525
TVM_FFI_STATIC_INIT_BLOCK() {
526
527
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.Simplify", Simplify);
528
}
529
530
531

} // namespace tl
} // namespace tvm