"src/git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "a39cd633f88b84f4f6230eb0bf8eda5a52030c8b"
simplify.cc 18.6 KB
Newer Older
1
2
/*!
 * \file simplify.cc
 * \brief Statement simplifier built on the arithmetic analyzer. It can also
 * remove unused buffer parameters from a TL PrimFunc.
 */

7
#include <tvm/ffi/reflection/registry.h>
8
#include <tvm/tir/analysis.h>
9
#include <tvm/tir/buffer.h>
10
11
12
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
13
#include <tvm/tir/utils.h>
14

15
#include <optional>
16
17
#include <utility>

18
19
#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/control_flow_graph.h"
20
#include "tir/analysis/var_use_def_analysis.h"
21
22
23
24
25

namespace tvm {
namespace tl {

using namespace tir;
26
using namespace ffi;
27
28
using namespace arith;

29
/*!
 * \brief Attribute node holding the switches of the "tl.Simplify" pass.
 *
 * Each flag is registered with a default of false in RegisterReflection().
 */
struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
  /*! \brief Simplify conditionals with transitive combinations of scoped
   * constraints. */
  bool transitively_prove_inequalities{};
  /*! \brief Propagate known buffer values to statically prove conditionals. */
  bool propagate_knowns_to_prove_conditional{};
  /*! \brief Propagate known buffer values to replace BufferLoad wherever
   * possible. */
  bool propagate_knowns_to_simplify_expressions{};
  /*! \brief Simplify conditionals into an AND of ORs. */
  bool convert_boolean_to_and_of_ors{};
  /*! \brief Simplify each branch of AND/OR under the constraints provided by
   * the other branch. */
  bool apply_constraints_to_boolean_branches{};

  /*! \brief Register every field (read-only, default false) with the FFI
   * reflection registry. */
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<SimplifyConfigNode>()
        .def_ro("transitively_prove_inequalities",
                &SimplifyConfigNode::transitively_prove_inequalities,
                "If true, simplify conditionals with transitive combinations "
                "of scoped constraints",
                refl::DefaultValue(false))
        .def_ro("propagate_knowns_to_prove_conditional",
                &SimplifyConfigNode::propagate_knowns_to_prove_conditional,
                "If true, known buffer values are propagated and used to "
                "statically prove conditionals",
                refl::DefaultValue(false))
        .def_ro("propagate_knowns_to_simplify_expressions",
                &SimplifyConfigNode::propagate_knowns_to_simplify_expressions,
                "If true, known buffer values are propagated and used to "
                "replace BufferLoad wherever possible",
                refl::DefaultValue(false))
        .def_ro("convert_boolean_to_and_of_ors",
                &SimplifyConfigNode::convert_boolean_to_and_of_ors,
                "If true, simplify conditionals into an AND of ORs",
                refl::DefaultValue(false))
        .def_ro("apply_constraints_to_boolean_branches",
                &SimplifyConfigNode::apply_constraints_to_boolean_branches,
                "If true, simplify each branch of AND/OR under a constraints "
                "provided by the other branch",
                refl::DefaultValue(false));
  }

  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.SimplifyConfig",
                                    SimplifyConfigNode, BaseAttrsNode);

  /*!
   * \brief Translate the boolean flags into RewriteSimplifier extension bits.
   *
   * Only the three flags that have a matching RewriteSimplifier extension
   * contribute; the two propagate_knowns_* flags are not consulted here.
   *
   * \return Bitwise OR of the enabled extensions (kNone when none is set).
   */
  RewriteSimplifier::Extension GetEnabledExtensions() const {
    // Accumulate as a plain integer and convert back to the enum once.
    int enabled = RewriteSimplifier::kNone;
    if (transitively_prove_inequalities) {
      enabled |= RewriteSimplifier::kTransitivelyProveInequalities;
    }
    if (convert_boolean_to_and_of_ors) {
      enabled |= RewriteSimplifier::kConvertBooleanToAndOfOrs;
    }
    if (apply_constraints_to_boolean_branches) {
      enabled |= RewriteSimplifier::kApplyConstraintsToBooleanBranches;
    }
    return RewriteSimplifier::Extension(enabled);
  }
};

87
88
std::unordered_set<const BufferNode *>
CollectUsedBuffers(const PrimFunc &func) {
89
90
91
92
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

93
    Visitor(PrimFunc func) : func(std::move(func)) {}
94

95
96
97
98
99
100
    void VisitExpr_(const CallNode *op) override {
      for (const auto &arg : op->args) {
        for (const auto &it : func->buffer_map) {
          if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
            used_in_buffer_def_.insert(it.second.get());
          }
101
        }
102
103
      }
      StmtExprVisitor::VisitExpr_(op);
104
    }
105
    void VisitExpr_(const BufferLoadNode *op) override {
106
107
108
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
109
    void VisitStmt_(const BufferStoreNode *op) override {
110
111
112
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }
113
114
115
116
117
118
    void VisitStmt_(const BlockNode *op) override {
      for (const auto &buffer : op->alloc_buffers) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
119
        }
120
121
122
123
124
125
      }
      for (const auto &buffer : op->reads) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
126
        }
127
128
129
130
131
132
      }
      for (const auto &buffer : op->writes) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
133
        }
134
135
      }
      StmtExprVisitor::VisitStmt_(op);
136
137
    }

138
    void VisitBuffer(const Buffer &buf) {
139
140
141
      // Collect buffers that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
142
      for (const auto &dim : buf->shape) {
143
144
        usage(dim);
      }
145
      for (const auto &dim : buf->strides) {
146
147
148
149
        usage(dim);
      }
      usage(buf->elem_offset);

150
      for (const auto &buffer : usage.buffer_use_count_) {
151
        if (buffer.second >= 1) {
152
          used_in_buffer_def_.insert(buffer.first);
153
154
        }
      }
155
      for (const auto &buffer : usage.undefined_buffers_) {
156
157
158
159
        used_in_buffer_def_.insert(buffer.get());
      }
    }
    PrimFunc func;
160
    std::unordered_set<const BufferNode *> used_in_buffer_def_;
161
162
163
164
165
166
167
  };

  Visitor visitor(func);
  visitor(func->body);
  return visitor.used_in_buffer_def_;
}

168
169
170
171
172
/* \brief Utility function to collect vars that should be retained. Used in
 * Letstmt Only
 */
std::unordered_set<const VarNode *>
CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
173
174
175
176
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

177
    void VisitExpr_(const BufferLoadNode *op) override {
178
179
180
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
181
    void VisitStmt_(const BufferStoreNode *op) override {
182
183
184
185
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }

186
    void VisitBuffer(const Buffer &buf) {
187
188
189
      // Collect variables that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
190
      for (const auto &dim : buf->shape) {
191
192
        usage(dim);
      }
193
      for (const auto &dim : buf->strides) {
194
195
196
197
198
        usage(dim);
      }
      usage(buf->elem_offset);

      // Track for use in LetStmtNode mutator
199
      for (const auto &var : usage.undefined_) {
200
201
202
        used_in_buffer_def_.insert(var.get());
      }
    }
203
    std::unordered_set<const VarNode *> used_in_buffer_def_;
204
205
206
207
208
209
210
211
  };

  Visitor visitor;
  visitor(stmt);
  return visitor.used_in_buffer_def_;
}

class SimplifyConfig : public Attrs {
212
public:
213
214
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs,
                                                SimplifyConfigNode);
215
};
216
// Register the SimplifyConfigNode fields with the FFI reflection registry.
TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); }
217
218
219
220

// Declare "tl.Simplify" as a pass config option whose value is a
// SimplifyConfig; the Simplify pass reads it through PassContext::GetConfig.
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);

class StmtSimplifier : public IRMutatorWithAnalyzer {
221
public:
222
223
224
225
  static PrimFunc
  Apply(PrimFunc func, Analyzer *analyzer,
        const Optional<SimplifyConfig> &config_opt = std::nullopt,
        bool simplify_arguments = false) {
226
    auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
227
228
    analyzer->rewrite_simplify.SetEnabledExtensions(
        config->GetEnabledExtensions());
229
230
231
232
233
234
235

    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
    if (config->propagate_knowns_to_prove_conditional ||
        config->propagate_knowns_to_simplify_expressions) {
      touch_pattern = ControlFlowGraph(func->body);
    }

236
    std::unordered_set<const VarNode *> used_in_buffer_def =
237
238
239
240
241
242
        CollectVarsUsedInBufferDefinition(func->body);
    StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
                              std::move(used_in_buffer_def));
    simplifier.MarkBufferMapShapes(func);
    func.CopyOnWrite()->body = simplifier(func->body);

243
244
245
246
    // Optionally remove unused buffer parameters
    if (simplify_arguments) {
      // First get used buffers
      simplifier.used_buffers_ = CollectUsedBuffers(func);
247

248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
      bool param_updated = false;
      Array<Var> new_params;
      Map<Var, Buffer> new_buffer_map;
      // Check whether each buffer is used
      for (const auto &var : func->params) {
        if (func->buffer_map.find(var) != func->buffer_map.end()) {
          if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
              simplifier.used_buffers_.end()) {
            new_params.push_back(var);
            new_buffer_map.Set(var, func->buffer_map[var]);
          } else if (simplifier.used_in_buffer_def_.find(
                         func->buffer_map[var]->data.get()) !=
                     simplifier.used_in_buffer_def_.end()) {
            new_params.push_back(var);
            new_buffer_map.Set(var, func->buffer_map[var]);
          } else {
            param_updated = true;
          }
266
        } else {
267
268
          // Non-buffer parameters (e.g., scalars) are always retained
          new_params.push_back(var);
269
        }
270
      }
271

272
273
274
275
      if (param_updated) {
        return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
                        new_buffer_map, func->attrs, func->span);
      }
276
    }
277
278
    // Either no change to params or argument simplification disabled
    return func;
279
280
  }

281
282
283
284
285
private:
  explicit StmtSimplifier(
      Analyzer *analyzer, SimplifyConfig config,
      std::optional<ControlFlowGraph> touch_pattern,
      std::unordered_set<const VarNode *> used_in_buffer_def)
286
287
288
      : IRMutatorWithAnalyzer(analyzer), config_(std::move(config)),
        touch_pattern_(std::move(touch_pattern)),
        used_in_buffer_def_(std::move(used_in_buffer_def)) {}
289
290
291
292
293
294

  using Parent = IRMutatorWithAnalyzer;
  using Parent::VisitExpr_;
  using Parent::VisitStmt;
  using Parent::VisitStmt_;

295
  PrimExpr VisitExpr(const PrimExpr &expr) final {
296
    if (config_->propagate_knowns_to_simplify_expressions) {
297
298
      return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(),
                                               analyzer_);
299
300
301
302
303
304
305
    } else {
      return analyzer_->Simplify(expr);
    }
  }

  Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }

306
  Stmt VisitStmt(const Stmt &stmt) override {
307
308
309
310
311
312
313
    Optional<Stmt> cache = this->current_stmt_;
    this->current_stmt_ = stmt;
    Stmt output = Parent::VisitStmt(stmt);
    this->current_stmt_ = std::move(cache);
    return output;
  }

314
  Stmt VisitStmt_(const ForNode *op) final {
315
316
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
    With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
317
318
    With<ConstraintContext> ctx2(analyzer_,
                                 op->loop_var < op->min + op->extent);
319
320
321
    return Parent::VisitStmt_(op);
  }

322
323
324
325
326
  bool CanInlineLetStmt(const LetStmtNode *op) {
    if (is_const_number(op->value))
      return true;
    if (op->value.as<VarNode>())
      return true;
327
    // Won't face the deep expression explosion problem as in Let expression.
328
329
330
331
    // attempt to inline as much as possible if the value integer type(can be
    // index).
    if (!op->value.dtype().is_int())
      return false;
332
333
334
    return SideEffect(op->value) <= CallEffectKind::kPure;
  }

335
  Stmt VisitStmt_(const LetStmtNode *op) override {
336
    PrimExpr value = this->VisitExpr(op->value);
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
    bool remove_buffer_alias = false;
    // TileLang emits aliases like `X_shared = buffer[0:128, 0:32]` to annotate
    // fragment types. TVM currently reinterprets vectorized/shared accesses as
    // Let-bound BufferLoad/BufferRegion nodes. If these bindings survive, later
    // passes (Layout rewrite, FlattenBuffer) substitute them with vector lanes
    // that our layout can't handle. Force-inline (by dropping the let) whenever
    // the alias spans more than 2 dims or carries vector lanes.
    auto get_ranges = [&](const PrimExpr &expr) -> Array<Range> {
      Array<Range> ranges;
      if (const auto *load = expr.as<BufferLoadNode>()) {
        for (const PrimExpr &index : load->indices) {
          if (const auto *ramp = index.as<RampNode>()) {
            ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
          } else {
            ranges.push_back(Range::FromMinExtent(index, Integer(1)));
          }
        }
      } else if (const auto *region = expr.as<BufferRegionNode>()) {
        for (const Range &range : region->region) {
          ranges.push_back(range);
        }
      }
      return ranges;
    };
    Array<Range> ranges = get_ranges(value);
    if (!ranges.empty()) {
      int non_unit_dims = 0;
      for (const Range &range : ranges) {
        PrimExpr extent = analyzer_->Simplify(range->extent);
        if (is_const_int(extent, 1) || analyzer_->CanProveEqual(extent, 1)) {
          continue;
        }
        ++non_unit_dims;
        if (non_unit_dims > 1) {
          remove_buffer_alias = true;
          break;
        }
      }
    }
    if (remove_buffer_alias) {
      Stmt body = this->VisitStmt(op->body);
      bool used = UsesVar(
          body, [&](const VarNode *var) { return var == op->var.get(); });
      ICHECK(!used) << "Let binding of BufferLoad is expected to be unused "
                       "before removal "
                    << op->var << " : " << op->value << " .";
      return body;
    }

386
387
388
389
390
391
392
393
394
395
396
397
398
    bool can_inline = CanInlineLetStmt(op);
    if (can_inline) {
      analyzer_->Bind(op->var, value);
    } else if (SideEffect(op->value) <= CallEffectKind::kPure) {
      non_inlined_bindings_.Set(op->var, value);
    }
    Stmt body = this->VisitStmt(op->body);

    bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

    if (can_inline && !used_in_buffer_def) {
      return body;
    } else if (value.same_as(op->value) && body.same_as(op->body)) {
399
      return tvm::ffi::GetRef<Stmt>(op);
400
401
402
403
404
405
406
407
    } else {
      auto n = this->CopyOnWrite(op);
      n->value = std::move(value);
      n->body = std::move(body);
      return Stmt(n);
    }
  }

408
  Stmt VisitStmt_(const IfThenElseNode *op) override {
409
410
411
412
413
414
415
416
417
418
419
420
421
    if (Optional<Bool> cond = ProveCondition(op->condition)) {
      if (cond.value()->value) {
        return this->VisitStmt(op->then_case);
      } else if (op->else_case) {
        return this->VisitStmt(op->else_case.value());
      } else {
        return Evaluate(0);
      }
    } else {
      return Parent::VisitStmt_(op);
    }
  }

422
  PrimExpr VisitExpr_(const CallNode *op) override {
423
424
425
426
427
428
429
430
431
432
433
434
    if (op->op.same_as(builtin::if_then_else())) {
      if (Optional<Bool> cond = ProveCondition(op->args[0])) {
        if (cond.value()->value) {
          return this->VisitExpr(op->args[1]);
        } else {
          return this->VisitExpr(op->args[2]);
        }
      }
    }
    return Parent::VisitExpr_(op);
  }

435
  PrimExpr VisitExpr_(const VarNode *op) override {
436
437
438
439
    used_vars_.insert(op);
    return Parent::VisitExpr_(op);
  }

440
  PrimExpr VisitExpr_(const BufferLoadNode *op) override {
441
442
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
443
      used_buffers_.insert(buffer);
444
445
446
447
448
    }
    return Parent::VisitExpr_(op);
  }

  // eliminate useless stores
449
  Stmt VisitStmt_(const BufferStoreNode *op) override {
450
    BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
451
    if (const BufferLoadNode *load = store->value.as<BufferLoadNode>()) {
452
453
      if (load->buffer->data.same_as(store->buffer->data) &&
          ArrayDeepEqual(load->indices, store->indices) &&
454
455
          tir::ExprDeepEqual()(load->buffer->elem_offset,
                               store->buffer->elem_offset) &&
456
457
458
459
460
461
462
          ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
          ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
        return Evaluate(0);
      }
    }
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
463
      used_buffers_.insert(buffer);
464
465
466
467
    }
    return std::move(store);
  }

468
469
470
471
472
473
474
475
476
477
  Stmt VisitStmt_(const AttrStmtNode *op) override {
    if (op->attr_key == "tl.assume") {
      PrimExpr condition = this->VisitExpr(Downcast<PrimExpr>(op->node));
      auto n = CopyOnWrite(op);
      n->node = std::move(condition);
      return Parent::VisitStmt_(n.get());
    }
    return Parent::VisitStmt_(op);
  }

478
479
private:
  bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
    if (lhs.size() != rhs.size()) {
      return false;
    }
    for (size_t i = 0; i < lhs.size(); i++) {
      if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) {
        return false;
      }
    }
    return true;
  }

  /* \brief Internal utility for checking conditionals
   *
   * Uses more aggressive optimization, such as performing additional
   * inlining and tracking known buffer values.
   */
  Optional<Bool> ProveCondition(PrimExpr condition) const {
    condition = Substitute(condition, non_inlined_bindings_);
    if (config_->propagate_knowns_to_prove_conditional) {
      ICHECK(touch_pattern_.has_value());
500
501
      condition = touch_pattern_->SimplifyInContext(
          condition, current_stmt_.value(), analyzer_);
502
503
504
    } else {
      condition = analyzer_->Simplify(condition);
    }
505
    if (const int64_t *as_int = as_const_int(condition)) {
506
507
      return Bool(*as_int);
    } else {
508
509
510
511
512
513
      // May have symbolic, need kSymbolicBound level prover.
      if (analyzer_->CanProve(condition) ||
          analyzer_->CanProve(condition,
                              arith::ProofStrength::kSymbolicBound)) {
        return Bool(true);
      }
514
      return std::nullopt;
515
516
517
518
519
520
521
    }
  }

  SimplifyConfig config_;
  std::optional<ControlFlowGraph> touch_pattern_;

  Map<Var, PrimExpr> non_inlined_bindings_;
522
  Optional<Stmt> current_stmt_{std::nullopt};
523
524
525
  std::unordered_set<const VarNode *> used_in_buffer_def_;
  std::unordered_set<const VarNode *> used_vars_;
  std::unordered_set<const BufferNode *> used_buffers_;
526
527
528
529
};

using namespace tir::transform;

530
/*!
 * \brief Create the "tl.Simplify" PrimFunc pass.
 *
 * The pass reads its SimplifyConfig from the "tl.Simplify" entry of the pass
 * context and runs StmtSimplifier::Apply with a fresh analyzer per function.
 *
 * \param simplify_arguments Forwarded to StmtSimplifier::Apply.
 * \return The created pass.
 */
tvm::transform::Pass Simplify(bool simplify_arguments = true) {
  auto run_on_func = [simplify_arguments](PrimFunc f, const IRModule &m,
                                          PassContext ctx) {
    arith::Analyzer analyzer;
    auto config = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
    return StmtSimplifier::Apply(std::move(f), &analyzer, config,
                                 simplify_arguments);
  };
  return CreatePrimFuncPass(run_on_func, 0, "tl.Simplify", {});
}

540
TVM_FFI_STATIC_INIT_BLOCK() {
541
542
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.Simplify", Simplify);
543
}
544
545
546

} // namespace tl
} // namespace tvm