/*! * \file simplify.cc * \brief Remove useless parameters of TL PrimFunc. */ #include #include #include #include #include #include "arith/ir_mutator_with_analyzer.h" #include "tir/analysis/control_flow_graph.h" #include "tir/analysis/var_use_def_analysis.h" namespace tvm { namespace tl { using namespace tir; using namespace arith; struct SimplifyConfigNode : public tvm::AttrsNode { bool transitively_prove_inequalities; bool propagate_knowns_to_prove_conditional; bool propagate_knowns_to_simplify_expressions; bool convert_boolean_to_and_of_ors; bool apply_constraints_to_boolean_branches; TVM_DECLARE_ATTRS(SimplifyConfigNode, "tl.transform.SimplifyConfig") { TVM_ATTR_FIELD(transitively_prove_inequalities) .describe("If true, simplify conditionals with transitive combinations " "of scoped constraints") .set_default(false); TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional) .describe("If true, known buffer values are propagated and used to " "statically prove conditionals") .set_default(false); TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions) .describe("If true, known buffer values are propagated and used to " "replace BufferLoad wherever " "possible") .set_default(false); TVM_ATTR_FIELD(convert_boolean_to_and_of_ors) .describe("If true, simplify conditionals into an AND of ORs") .set_default(false); TVM_ATTR_FIELD(apply_constraints_to_boolean_branches) .describe("If true, simplify each branch of AND/OR " "under a constraints provided by the other branch") .set_default(false); } RewriteSimplifier::Extension GetEnabledExtensions() const { RewriteSimplifier::Extension flags = RewriteSimplifier::kNone; if (transitively_prove_inequalities) { flags = RewriteSimplifier::Extension( flags | RewriteSimplifier::kTransitivelyProveInequalities); } if (convert_boolean_to_and_of_ors) { flags = RewriteSimplifier::Extension( flags | RewriteSimplifier::kConvertBooleanToAndOfOrs); } if (apply_constraints_to_boolean_branches) { flags = RewriteSimplifier::Extension( flags | RewriteSimplifier::kApplyConstraintsToBooleanBranches); } return flags; } }; std::unordered_set CollectUsedBuffers(const PrimFunc &func) { struct Visitor : StmtExprVisitor { using StmtExprVisitor::VisitExpr_; using StmtExprVisitor::VisitStmt_; Visitor(PrimFunc func) : func(func) {} void VisitExpr_(const CallNode *op) override { for (const auto &arg : op->args) { for (const auto &it : func->buffer_map) { if (Downcast(it.second.get()->data).same_as(arg)) { used_in_buffer_def_.insert(it.second.get()); } } } StmtExprVisitor::VisitExpr_(op); } void VisitExpr_(const BufferLoadNode *op) override { VisitBuffer(op->buffer); StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode *op) override { VisitBuffer(op->buffer); StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const BlockNode *op) override { for (const auto &buffer : op->alloc_buffers) { for (const auto &it : func->buffer_map) { if (it.second.get()->data.same_as(buffer.get()->data)) { used_in_buffer_def_.insert(it.second.get()); } } } for (const auto &buffer : op->reads) { for (const auto &it : func->buffer_map) { if (it.second.get()->data.same_as(buffer->buffer.get()->data)) { used_in_buffer_def_.insert(it.second.get()); } } } for (const auto &buffer : op->writes) { for (const auto &it : func->buffer_map) { if (it.second.get()->data.same_as(buffer->buffer.get()->data)) { used_in_buffer_def_.insert(it.second.get()); } } } StmtExprVisitor::VisitStmt_(op); } void VisitBuffer(const Buffer &buf) { // Collect buffers that should remain defined VarUseDefAnalyzer usage(Array{}); usage(buf->data); for (const auto &dim : buf->shape) { usage(dim); } for (const auto &dim : buf->strides) { usage(dim); } usage(buf->elem_offset); for (const auto &buffer : usage.buffer_use_count_) { if (buffer.second >= 1) { used_in_buffer_def_.insert(buffer.first); } } for (const auto &buffer : usage.undefined_buffers_) { used_in_buffer_def_.insert(buffer.get()); } } PrimFunc func; std::unordered_set used_in_buffer_def_; }; Visitor visitor(func); visitor(func->body); return visitor.used_in_buffer_def_; } /* \brief Utility function to collect vars that should be retained. Used in * Letstmt Only */ std::unordered_set CollectVarsUsedInBufferDefinition(const Stmt &stmt) { struct Visitor : StmtExprVisitor { using StmtExprVisitor::VisitExpr_; using StmtExprVisitor::VisitStmt_; void VisitExpr_(const BufferLoadNode *op) override { VisitBuffer(op->buffer); StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode *op) override { VisitBuffer(op->buffer); StmtExprVisitor::VisitStmt_(op); } void VisitBuffer(const Buffer &buf) { // Collect variables that should remain defined VarUseDefAnalyzer usage(Array{}); usage(buf->data); for (const auto &dim : buf->shape) { usage(dim); } for (const auto &dim : buf->strides) { usage(dim); } usage(buf->elem_offset); // Track for use in LetStmtNode mutator for (const auto &var : usage.undefined_) { used_in_buffer_def_.insert(var.get()); } } std::unordered_set used_in_buffer_def_; }; Visitor visitor; visitor(stmt); return visitor.used_in_buffer_def_; } class SimplifyConfig : public Attrs { public: TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, SimplifyConfigNode); }; TVM_REGISTER_NODE_TYPE(SimplifyConfigNode); TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig); class StmtSimplifier : public IRMutatorWithAnalyzer { public: static PrimFunc Apply(PrimFunc func, Analyzer *analyzer, Optional config_opt = NullOpt) { auto config = config_opt.value_or(AttrsWithDefaultValues()); analyzer->rewrite_simplify.SetEnabledExtensions( config->GetEnabledExtensions()); std::optional touch_pattern = std::nullopt; if (config->propagate_knowns_to_prove_conditional || config->propagate_knowns_to_simplify_expressions) { touch_pattern = ControlFlowGraph(func->body); } std::unordered_set used_in_buffer_def = CollectVarsUsedInBufferDefinition(func->body); StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern), std::move(used_in_buffer_def)); simplifier.MarkBufferMapShapes(func); func.CopyOnWrite()->body = simplifier(func->body); // Begin to remove useless var and buffer // First get used buffers simplifier.used_buffers_ = CollectUsedBuffers(func); bool param_updated = false; Array new_params; Map new_buffer_map; // Check whether each buffer is used for (const auto &var : func->params) { if (func->buffer_map.find(var) != func->buffer_map.end()) { if (simplifier.used_buffers_.find(func->buffer_map[var].get()) != simplifier.used_buffers_.end()) { new_params.push_back(var); new_buffer_map.Set(var, func->buffer_map[var]); } else { param_updated = true; } } } // return func; if (param_updated) { return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type, new_buffer_map, func->attrs, func->span); } else { return func; } } private: explicit StmtSimplifier( Analyzer *analyzer, SimplifyConfig config, std::optional touch_pattern, std::unordered_set used_in_buffer_def) : IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern), used_in_buffer_def_(used_in_buffer_def) { } using Parent = IRMutatorWithAnalyzer; using Parent::VisitExpr_; using Parent::VisitStmt; using Parent::VisitStmt_; PrimExpr VisitExpr(const PrimExpr &expr) final { if (config_->propagate_knowns_to_simplify_expressions) { return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(), analyzer_); } else { return analyzer_->Simplify(expr); } } Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); } Stmt VisitStmt(const Stmt &stmt) override { Optional cache = this->current_stmt_; this->current_stmt_ = stmt; Stmt output = Parent::VisitStmt(stmt); this->current_stmt_ = std::move(cache); return output; } Stmt VisitStmt_(const ForNode *op) final { analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); With ctx1(analyzer_, op->loop_var >= op->min); With ctx2(analyzer_, op->loop_var < op->min + op->extent); return Parent::VisitStmt_(op); } bool CanInlineLetStmt(const LetStmtNode *op) { if (is_const_number(op->value)) return true; if (op->value.as()) return true; // Won't face the deep expression explosion problem as in Let expression. // attempt to inline as much as possible if the value integer type(can be // index). if (!op->value.dtype().is_int()) return false; return SideEffect(op->value) <= CallEffectKind::kPure; } Stmt VisitStmt_(const LetStmtNode *op) override { PrimExpr value = this->VisitExpr(op->value); bool can_inline = CanInlineLetStmt(op); if (can_inline) { // It is usually fine to discard the let binding because the // call to simplify will always inline the var. // // The exception is when the variable is used in a Buffer's // definition, as these are not updated by the simplification. // After DeclBuffer is required prior to use of a buffer, // simplifying can update the buffer definition as well. The // buffer can only be updated at its point of definition, // because the points of use may occur within contexts that // allow for additional simplifications (e.g. a buffer of shape // [i,j] whose first use occurs within "if i==1" should not have // its shape simplified to [1,j]). analyzer_->Bind(op->var, value); } else if (SideEffect(op->value) <= CallEffectKind::kPure) { // Even if we aren't replacing all occurrences, they may be // necessary for proving conditional statements. non_inlined_bindings_.Set(op->var, value); } Stmt body = this->VisitStmt(op->body); // TODO(Lunderberg): Update the Buffer object as part of // DeclBuffer updates, which will first require // https://github.com/apache/tvm/pull/14778. bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get()); if (can_inline && !used_in_buffer_def) { return body; } else if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); n->value = std::move(value); n->body = std::move(body); return Stmt(n); } } Stmt VisitStmt_(const IfThenElseNode *op) override { if (Optional cond = ProveCondition(op->condition)) { if (cond.value()->value) { return this->VisitStmt(op->then_case); } else if (op->else_case) { return this->VisitStmt(op->else_case.value()); } else { return Evaluate(0); } } else { return Parent::VisitStmt_(op); } } PrimExpr VisitExpr_(const CallNode *op) override { if (op->op.same_as(builtin::if_then_else())) { if (Optional cond = ProveCondition(op->args[0])) { if (cond.value()->value) { return this->VisitExpr(op->args[1]); } else { return this->VisitExpr(op->args[2]); } } } return Parent::VisitExpr_(op); } PrimExpr VisitExpr_(const VarNode *op) override { used_vars_.insert(op); return Parent::VisitExpr_(op); } PrimExpr VisitExpr_(const BufferLoadNode *op) override { auto buffer = op->buffer.get(); if (used_buffers_.find(buffer) == used_buffers_.end()) { used_buffers_.insert(buffer); } return Parent::VisitExpr_(op); } // eliminate useless stores Stmt VisitStmt_(const BufferStoreNode *op) override { BufferStore store = Downcast(Parent::VisitStmt_(op)); if (const BufferLoadNode *load = store->value.as()) { if (load->buffer->data.same_as(store->buffer->data) && ArrayDeepEqual(load->indices, store->indices) && tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) && ArrayDeepEqual(load->buffer->shape, store->buffer->shape) && ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) { return Evaluate(0); } } auto buffer = op->buffer.get(); if (used_buffers_.find(buffer) == used_buffers_.end()) { used_buffers_.insert(buffer); } return std::move(store); } private: bool ArrayDeepEqual(const Array &lhs, const Array &rhs) { if (lhs.size() != rhs.size()) { return false; } for (size_t i = 0; i < lhs.size(); i++) { if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) { return false; } } return true; } /* \brief Internal utility for checking conditionals * * Uses more aggressive optimization, such as performing additional * inlining and tracking known buffer values. */ Optional ProveCondition(PrimExpr condition) const { condition = Substitute(condition, non_inlined_bindings_); if (config_->propagate_knowns_to_prove_conditional) { ICHECK(touch_pattern_.has_value()); condition = touch_pattern_->SimplifyInContext( condition, current_stmt_.value(), analyzer_); } else { condition = analyzer_->Simplify(condition); } if (const int64_t *as_int = as_const_int(condition)) { return Bool(*as_int); } else { return NullOpt; } } SimplifyConfig config_; std::optional touch_pattern_; Map non_inlined_bindings_; Optional current_stmt_{NullOpt}; std::unordered_set used_in_buffer_def_; std::unordered_set used_vars_; std::unordered_set used_buffers_; }; using namespace tir::transform; tvm::transform::Pass Simplify() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { arith::Analyzer analyzer; auto cfg = ctx->GetConfig("tl.Simplify"); return StmtSimplifier::Apply(f, &analyzer, cfg); }; return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {}); } TVM_REGISTER_GLOBAL("tl.transform.Simplify").set_body_typed(Simplify); } // namespace tl } // namespace tvm