"docs/vscode:/vscode.git/clone" did not exist on "d1c6bbae60ac455284a0bb5a96fa4991de80e9fe"
simplify.cc 15.5 KB
Newer Older
1
2
3
4
5
6
7
8
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file simplify.cc
 * \brief Simplify the body of a TL PrimFunc and remove buffer parameters that
 *        are no longer used.
 */

9
#include <tvm/tir/buffer.h>
10
11
12
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
13
#include <tvm/tir/utils.h>
14
15
16

#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/control_flow_graph.h"
17
#include "tir/analysis/var_use_def_analysis.h"
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

namespace tvm {
namespace tl {

using namespace tir;
using namespace arith;

/*!
 * \brief Attribute node holding the options of the "tl.Simplify" pass.
 *
 * Every option defaults to false. Three of them map directly onto
 * RewriteSimplifier extension flags (see GetEnabledExtensions); the two
 * `propagate_knowns_*` options are read by the simplifier itself.
 */
struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
  /*! \brief Simplify conditionals with transitive combinations of scoped
   *         constraints. */
  bool transitively_prove_inequalities;
  /*! \brief Propagate known buffer values and use them to statically prove
   *         conditionals. */
  bool propagate_knowns_to_prove_conditional;
  /*! \brief Propagate known buffer values and use them to replace BufferLoad
   *         wherever possible. */
  bool propagate_knowns_to_simplify_expressions;
  /*! \brief Simplify conditionals into an AND of ORs. */
  bool convert_boolean_to_and_of_ors;
  /*! \brief Simplify each branch of AND/OR under the constraints provided by
   *         the other branch. */
  bool apply_constraints_to_boolean_branches;

  TVM_DECLARE_ATTRS(SimplifyConfigNode, "tl.transform.SimplifyConfig") {
    TVM_ATTR_FIELD(transitively_prove_inequalities)
        .describe("If true, simplify conditionals with transitive combinations "
                  "of scoped constraints")
        .set_default(false);

    TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional)
        .describe("If true, known buffer values are propagated and used to "
                  "statically prove conditionals")
        .set_default(false);

    TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions)
        .describe("If true, known buffer values are propagated and used to "
                  "replace BufferLoad wherever "
                  "possible")
        .set_default(false);

    TVM_ATTR_FIELD(convert_boolean_to_and_of_ors)
        .describe("If true, simplify conditionals into an AND of ORs")
        .set_default(false);

    TVM_ATTR_FIELD(apply_constraints_to_boolean_branches)
        .describe("If true, simplify each branch of AND/OR "
                  "under a constraints provided by the other branch")
        .set_default(false);
  }

  /*!
   * \brief Translate the boolean options into RewriteSimplifier extensions.
   * \return The bitwise OR of the extension flags whose option is set, or
   *         RewriteSimplifier::kNone when none of them is.
   */
  RewriteSimplifier::Extension GetEnabledExtensions() const {
    RewriteSimplifier::Extension enabled = RewriteSimplifier::kNone;

    // OR `bit` into the result when the corresponding option was requested.
    auto enable_when = [&enabled](bool requested,
                                  RewriteSimplifier::Extension bit) {
      if (requested) {
        enabled = RewriteSimplifier::Extension(enabled | bit);
      }
    };

    enable_when(transitively_prove_inequalities,
                RewriteSimplifier::kTransitivelyProveInequalities);
    enable_when(convert_boolean_to_and_of_ors,
                RewriteSimplifier::kConvertBooleanToAndOfOrs);
    enable_when(apply_constraints_to_boolean_branches,
                RewriteSimplifier::kApplyConstraintsToBooleanBranches);
    return enabled;
  }
};

77
78
std::unordered_set<const BufferNode *>
CollectUsedBuffers(const PrimFunc &func) {
79
80
81
82
83
84
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

    Visitor(PrimFunc func) : func(func) {}

85
86
87
88
89
90
    void VisitExpr_(const CallNode *op) override {
      for (const auto &arg : op->args) {
        for (const auto &it : func->buffer_map) {
          if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
            used_in_buffer_def_.insert(it.second.get());
          }
91
        }
92
93
      }
      StmtExprVisitor::VisitExpr_(op);
94
    }
95
    void VisitExpr_(const BufferLoadNode *op) override {
96
97
98
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
99
    void VisitStmt_(const BufferStoreNode *op) override {
100
101
102
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }
103
104
105
106
107
108
    void VisitStmt_(const BlockNode *op) override {
      for (const auto &buffer : op->alloc_buffers) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
109
        }
110
111
112
113
114
115
      }
      for (const auto &buffer : op->reads) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
116
        }
117
118
119
120
121
122
      }
      for (const auto &buffer : op->writes) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
123
        }
124
125
      }
      StmtExprVisitor::VisitStmt_(op);
126
127
    }

128
    void VisitBuffer(const Buffer &buf) {
129
130
131
      // Collect buffers that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
132
      for (const auto &dim : buf->shape) {
133
134
        usage(dim);
      }
135
      for (const auto &dim : buf->strides) {
136
137
138
139
        usage(dim);
      }
      usage(buf->elem_offset);

140
      for (const auto &buffer : usage.buffer_use_count_) {
141
        if (buffer.second >= 1) {
142
          used_in_buffer_def_.insert(buffer.first);
143
144
        }
      }
145
      for (const auto &buffer : usage.undefined_buffers_) {
146
147
148
149
        used_in_buffer_def_.insert(buffer.get());
      }
    }
    PrimFunc func;
150
    std::unordered_set<const BufferNode *> used_in_buffer_def_;
151
152
153
154
155
156
157
  };

  Visitor visitor(func);
  visitor(func->body);
  return visitor.used_in_buffer_def_;
}

158
159
160
161
162
/* \brief Utility function to collect vars that should be retained. Used in
 * Letstmt Only
 */
std::unordered_set<const VarNode *>
CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
163
164
165
166
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;

167
    void VisitExpr_(const BufferLoadNode *op) override {
168
169
170
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
171
    void VisitStmt_(const BufferStoreNode *op) override {
172
173
174
175
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }

176
    void VisitBuffer(const Buffer &buf) {
177
178
179
      // Collect variables that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
180
      for (const auto &dim : buf->shape) {
181
182
        usage(dim);
      }
183
      for (const auto &dim : buf->strides) {
184
185
186
187
188
        usage(dim);
      }
      usage(buf->elem_offset);

      // Track for use in LetStmtNode mutator
189
      for (const auto &var : usage.undefined_) {
190
191
192
        used_in_buffer_def_.insert(var.get());
      }
    }
193
    std::unordered_set<const VarNode *> used_in_buffer_def_;
194
195
196
197
198
199
200
201
  };

  Visitor visitor;
  visitor(stmt);
  return visitor.used_in_buffer_def_;
}

class SimplifyConfig : public Attrs {
202
203
204
public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
                                            SimplifyConfigNode);
205
206
207
208
209
210
};

// Register the config node type, and expose SimplifyConfig as the pass-config
// option "tl.Simplify" (the key the Simplify() pass reads from its
// PassContext).
TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);

class StmtSimplifier : public IRMutatorWithAnalyzer {
211
212
public:
  static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
213
214
                        Optional<SimplifyConfig> config_opt = NullOpt) {
    auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
215
216
    analyzer->rewrite_simplify.SetEnabledExtensions(
        config->GetEnabledExtensions());
217
218
219
220
221
222
223

    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
    if (config->propagate_knowns_to_prove_conditional ||
        config->propagate_knowns_to_simplify_expressions) {
      touch_pattern = ControlFlowGraph(func->body);
    }

224
    std::unordered_set<const VarNode *> used_in_buffer_def =
225
226
227
228
229
230
231
232
233
234
235
236
237
        CollectVarsUsedInBufferDefinition(func->body);
    StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
                              std::move(used_in_buffer_def));
    simplifier.MarkBufferMapShapes(func);
    func.CopyOnWrite()->body = simplifier(func->body);

    // Begin to remove useless var and buffer
    // First get used buffers
    simplifier.used_buffers_ = CollectUsedBuffers(func);
    bool param_updated = false;
    Array<Var> new_params;
    Map<Var, Buffer> new_buffer_map;
    // Check whether each buffer is used
238
239
240
241
242
243
244
245
    for (const auto &var : func->params) {
      if (func->buffer_map.find(var) != func->buffer_map.end()) {
        if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
            simplifier.used_buffers_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
        } else {
          param_updated = true;
246
        }
247
      }
248
249
250
    }
    // return func;
    if (param_updated) {
251
252
      return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
                      new_buffer_map, func->attrs, func->span);
253
    } else {
254
      return func;
255
256
257
    }
  }

258
259
260
261
262
263
264
265
private:
  explicit StmtSimplifier(
      Analyzer *analyzer, SimplifyConfig config,
      std::optional<ControlFlowGraph> touch_pattern,
      std::unordered_set<const VarNode *> used_in_buffer_def)
      : IRMutatorWithAnalyzer(analyzer), config_(config),
        touch_pattern_(touch_pattern), used_in_buffer_def_(used_in_buffer_def) {
  }
266
267
268
269
270
271

  using Parent = IRMutatorWithAnalyzer;
  using Parent::VisitExpr_;
  using Parent::VisitStmt;
  using Parent::VisitStmt_;

272
  PrimExpr VisitExpr(const PrimExpr &expr) final {
273
    if (config_->propagate_knowns_to_simplify_expressions) {
274
275
      return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(),
                                               analyzer_);
276
277
278
279
280
281
282
    } else {
      return analyzer_->Simplify(expr);
    }
  }

  Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }

283
  Stmt VisitStmt(const Stmt &stmt) override {
284
285
286
287
288
289
290
    Optional<Stmt> cache = this->current_stmt_;
    this->current_stmt_ = stmt;
    Stmt output = Parent::VisitStmt(stmt);
    this->current_stmt_ = std::move(cache);
    return output;
  }

291
  Stmt VisitStmt_(const ForNode *op) final {
292
293
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
    With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
294
295
    With<ConstraintContext> ctx2(analyzer_,
                                 op->loop_var < op->min + op->extent);
296
297
298
    return Parent::VisitStmt_(op);
  }

299
300
301
302
303
  bool CanInlineLetStmt(const LetStmtNode *op) {
    if (is_const_number(op->value))
      return true;
    if (op->value.as<VarNode>())
      return true;
304
    // Won't face the deep expression explosion problem as in Let expression.
305
306
307
308
    // attempt to inline as much as possible if the value integer type(can be
    // index).
    if (!op->value.dtype().is_int())
      return false;
309
310
311
    return SideEffect(op->value) <= CallEffectKind::kPure;
  }

312
  Stmt VisitStmt_(const LetStmtNode *op) override {
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
    PrimExpr value = this->VisitExpr(op->value);
    bool can_inline = CanInlineLetStmt(op);
    if (can_inline) {
      // It is usually fine to discard the let binding because the
      // call to simplify will always inline the var.
      //
      // The exception is when the variable is used in a Buffer's
      // definition, as these are not updated by the simplification.
      // After DeclBuffer is required prior to use of a buffer,
      // simplifying can update the buffer definition as well.  The
      // buffer can only be updated at its point of definition,
      // because the points of use may occur within contexts that
      // allow for additional simplifications (e.g. a buffer of shape
      // [i,j] whose first use occurs within "if i==1" should not have
      // its shape simplified to [1,j]).
      analyzer_->Bind(op->var, value);
    } else if (SideEffect(op->value) <= CallEffectKind::kPure) {
      // Even if we aren't replacing all occurrences, they may be
      // necessary for proving conditional statements.
      non_inlined_bindings_.Set(op->var, value);
    }
    Stmt body = this->VisitStmt(op->body);

    // TODO(Lunderberg): Update the Buffer object as part of
    // DeclBuffer updates, which will first require
    // https://github.com/apache/tvm/pull/14778.
    bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

    if (can_inline && !used_in_buffer_def) {
      return body;
    } else if (value.same_as(op->value) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = this->CopyOnWrite(op);
      n->value = std::move(value);
      n->body = std::move(body);
      return Stmt(n);
    }
  }

353
  Stmt VisitStmt_(const IfThenElseNode *op) override {
354
355
356
357
358
359
360
361
362
363
364
365
366
    if (Optional<Bool> cond = ProveCondition(op->condition)) {
      if (cond.value()->value) {
        return this->VisitStmt(op->then_case);
      } else if (op->else_case) {
        return this->VisitStmt(op->else_case.value());
      } else {
        return Evaluate(0);
      }
    } else {
      return Parent::VisitStmt_(op);
    }
  }

367
  PrimExpr VisitExpr_(const CallNode *op) override {
368
369
370
371
372
373
374
375
376
377
378
379
    if (op->op.same_as(builtin::if_then_else())) {
      if (Optional<Bool> cond = ProveCondition(op->args[0])) {
        if (cond.value()->value) {
          return this->VisitExpr(op->args[1]);
        } else {
          return this->VisitExpr(op->args[2]);
        }
      }
    }
    return Parent::VisitExpr_(op);
  }

380
  PrimExpr VisitExpr_(const VarNode *op) override {
381
382
383
384
    used_vars_.insert(op);
    return Parent::VisitExpr_(op);
  }

385
  PrimExpr VisitExpr_(const BufferLoadNode *op) override {
386
387
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
388
      used_buffers_.insert(buffer);
389
390
391
392
393
    }
    return Parent::VisitExpr_(op);
  }

  // eliminate useless stores
394
  Stmt VisitStmt_(const BufferStoreNode *op) override {
395
    BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
396
    if (const BufferLoadNode *load = store->value.as<BufferLoadNode>()) {
397
398
      if (load->buffer->data.same_as(store->buffer->data) &&
          ArrayDeepEqual(load->indices, store->indices) &&
399
400
          tir::ExprDeepEqual()(load->buffer->elem_offset,
                               store->buffer->elem_offset) &&
401
402
403
404
405
406
407
          ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
          ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
        return Evaluate(0);
      }
    }
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
408
      used_buffers_.insert(buffer);
409
410
411
412
    }
    return std::move(store);
  }

413
414
private:
  bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
    if (lhs.size() != rhs.size()) {
      return false;
    }
    for (size_t i = 0; i < lhs.size(); i++) {
      if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) {
        return false;
      }
    }
    return true;
  }

  /* \brief Internal utility for checking conditionals
   *
   * Uses more aggressive optimization, such as performing additional
   * inlining and tracking known buffer values.
   */
  Optional<Bool> ProveCondition(PrimExpr condition) const {
    condition = Substitute(condition, non_inlined_bindings_);
    if (config_->propagate_knowns_to_prove_conditional) {
      ICHECK(touch_pattern_.has_value());
435
436
      condition = touch_pattern_->SimplifyInContext(
          condition, current_stmt_.value(), analyzer_);
437
438
439
    } else {
      condition = analyzer_->Simplify(condition);
    }
440
    if (const int64_t *as_int = as_const_int(condition)) {
441
442
443
444
445
446
447
448
449
450
451
      return Bool(*as_int);
    } else {
      return NullOpt;
    }
  }

  SimplifyConfig config_;
  std::optional<ControlFlowGraph> touch_pattern_;

  Map<Var, PrimExpr> non_inlined_bindings_;
  Optional<Stmt> current_stmt_{NullOpt};
452
453
454
  std::unordered_set<const VarNode *> used_in_buffer_def_;
  std::unordered_set<const VarNode *> used_vars_;
  std::unordered_set<const BufferNode *> used_buffers_;
455
456
457
458
459
};

using namespace tir::transform;

/*!
 * \brief Create the "tl.Simplify" PrimFunc pass.
 *
 * For each function the pass reads the optional "tl.Simplify" configuration
 * from the pass context and runs StmtSimplifier::Apply with a fresh analyzer.
 *
 * \return The pass object.
 */
tvm::transform::Pass Simplify() {
  auto simplify_prim_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto pass_config = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
    arith::Analyzer local_analyzer;
    return StmtSimplifier::Apply(f, &local_analyzer, pass_config);
  };
  return CreatePrimFuncPass(simplify_prim_func, /*opt_level=*/0, "tl.Simplify",
                            {});
}

// Register the Simplify() pass factory as the global function
// "tl.transform.Simplify".
TVM_REGISTER_GLOBAL("tl.transform.Simplify").set_body_typed(Simplify);

} // namespace tl
} // namespace tvm