/*! * \file simplify.cc * \brief Statement simplifier based on analyzer and remove useless parameters * of TL PrimFunc. */ #include #include #include #include #include #include #include #include #include #include "arith/ir_mutator_with_analyzer.h" #include "tir/analysis/control_flow_graph.h" #include "tir/analysis/var_use_def_analysis.h" namespace tvm { namespace tl { using namespace tir; using namespace arith; struct SimplifyConfigNode : public AttrsNodeReflAdapter { bool transitively_prove_inequalities{}; bool propagate_knowns_to_prove_conditional{}; bool propagate_knowns_to_simplify_expressions{}; bool convert_boolean_to_and_of_ors{}; bool apply_constraints_to_boolean_branches{}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("transitively_prove_inequalities", &SimplifyConfigNode::transitively_prove_inequalities, "If true, simplify conditionals with transitive combinations " "of scoped constraints", refl::DefaultValue(false)) .def_ro("propagate_knowns_to_prove_conditional", &SimplifyConfigNode::propagate_knowns_to_prove_conditional, "If true, known buffer values are propagated and used to " "statically prove conditionals", refl::DefaultValue(false)) .def_ro("propagate_knowns_to_simplify_expressions", &SimplifyConfigNode::propagate_knowns_to_simplify_expressions, "If true, known buffer values are propagated and used to " "replace BufferLoad wherever " "possible", refl::DefaultValue(false)) .def_ro("convert_boolean_to_and_of_ors", &SimplifyConfigNode::convert_boolean_to_and_of_ors, "If true, simplify conditionals into an AND of ORs", refl::DefaultValue(false)) .def_ro("apply_constraints_to_boolean_branches", &SimplifyConfigNode::apply_constraints_to_boolean_branches, "If true, simplify each branch of AND/OR under a constraints " "provided by the other " "branch", refl::DefaultValue(false)); } static constexpr const char *_type_key = "tl.transform.SimplifyConfig"; TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode); RewriteSimplifier::Extension GetEnabledExtensions() const { RewriteSimplifier::Extension flags = RewriteSimplifier::kNone; if (transitively_prove_inequalities) { flags = RewriteSimplifier::Extension( flags | RewriteSimplifier::kTransitivelyProveInequalities); } if (convert_boolean_to_and_of_ors) { flags = RewriteSimplifier::Extension( flags | RewriteSimplifier::kConvertBooleanToAndOfOrs); } if (apply_constraints_to_boolean_branches) { flags = RewriteSimplifier::Extension( flags | RewriteSimplifier::kApplyConstraintsToBooleanBranches); } return flags; } }; std::unordered_set CollectUsedBuffers(const PrimFunc &func) { struct Visitor : StmtExprVisitor { using StmtExprVisitor::VisitExpr_; using StmtExprVisitor::VisitStmt_; Visitor(PrimFunc func) : func(std::move(func)) {} void VisitExpr_(const CallNode *op) override { for (const auto &arg : op->args) { for (const auto &it : func->buffer_map) { if (Downcast(it.second.get()->data).same_as(arg)) { used_in_buffer_def_.insert(it.second.get()); } } } StmtExprVisitor::VisitExpr_(op); } void VisitExpr_(const BufferLoadNode *op) override { VisitBuffer(op->buffer); StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode *op) override { VisitBuffer(op->buffer); StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const BlockNode *op) override { for (const auto &buffer : op->alloc_buffers) { for (const auto &it : func->buffer_map) { if (it.second.get()->data.same_as(buffer.get()->data)) { used_in_buffer_def_.insert(it.second.get()); } } } for (const auto &buffer : op->reads) { for (const auto &it : func->buffer_map) { if (it.second.get()->data.same_as(buffer->buffer.get()->data)) { used_in_buffer_def_.insert(it.second.get()); } } } for (const auto &buffer : op->writes) { for (const auto &it : func->buffer_map) { if (it.second.get()->data.same_as(buffer->buffer.get()->data)) { used_in_buffer_def_.insert(it.second.get()); } } } StmtExprVisitor::VisitStmt_(op); } void VisitBuffer(const Buffer &buf) { // Collect buffers that should remain defined VarUseDefAnalyzer usage(Array{}); usage(buf->data); for (const auto &dim : buf->shape) { usage(dim); } for (const auto &dim : buf->strides) { usage(dim); } usage(buf->elem_offset); for (const auto &buffer : usage.buffer_use_count_) { if (buffer.second >= 1) { used_in_buffer_def_.insert(buffer.first); } } for (const auto &buffer : usage.undefined_buffers_) { used_in_buffer_def_.insert(buffer.get()); } } PrimFunc func; std::unordered_set used_in_buffer_def_; }; Visitor visitor(func); visitor(func->body); return visitor.used_in_buffer_def_; } /* \brief Utility function to collect vars that should be retained. Used in * Letstmt Only */ std::unordered_set CollectVarsUsedInBufferDefinition(const Stmt &stmt) { struct Visitor : StmtExprVisitor { using StmtExprVisitor::VisitExpr_; using StmtExprVisitor::VisitStmt_; void VisitExpr_(const BufferLoadNode *op) override { VisitBuffer(op->buffer); StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode *op) override { VisitBuffer(op->buffer); StmtExprVisitor::VisitStmt_(op); } void VisitBuffer(const Buffer &buf) { // Collect variables that should remain defined VarUseDefAnalyzer usage(Array{}); usage(buf->data); for (const auto &dim : buf->shape) { usage(dim); } for (const auto &dim : buf->strides) { usage(dim); } usage(buf->elem_offset); // Track for use in LetStmtNode mutator for (const auto &var : usage.undefined_) { used_in_buffer_def_.insert(var.get()); } } std::unordered_set used_in_buffer_def_; }; Visitor visitor; visitor(stmt); return visitor.used_in_buffer_def_; } class SimplifyConfig : public Attrs { public: TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, SimplifyConfigNode); }; TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(SimplifyConfigNode); TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig); class StmtSimplifier : public IRMutatorWithAnalyzer { public: static PrimFunc Apply(PrimFunc func, Analyzer *analyzer, const Optional &config_opt = std::nullopt, bool simplify_arguments = false) { auto config = config_opt.value_or(AttrsWithDefaultValues()); analyzer->rewrite_simplify.SetEnabledExtensions( config->GetEnabledExtensions()); std::optional touch_pattern = std::nullopt; if (config->propagate_knowns_to_prove_conditional || config->propagate_knowns_to_simplify_expressions) { touch_pattern = ControlFlowGraph(func->body); } std::unordered_set used_in_buffer_def = CollectVarsUsedInBufferDefinition(func->body); StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern), std::move(used_in_buffer_def)); simplifier.MarkBufferMapShapes(func); func.CopyOnWrite()->body = simplifier(func->body); // Begin to remove useless var and buffer // First get used buffers simplifier.used_buffers_ = CollectUsedBuffers(func); bool param_updated = false; Array new_params; Map new_buffer_map; // Check whether each buffer is used for (const auto &var : func->params) { if (func->buffer_map.find(var) != func->buffer_map.end()) { if (simplifier.used_buffers_.find(func->buffer_map[var].get()) != simplifier.used_buffers_.end()) { new_params.push_back(var); new_buffer_map.Set(var, func->buffer_map[var]); } else if (simplifier.used_in_buffer_def_.find( func->buffer_map[var]->data.get()) != simplifier.used_in_buffer_def_.end()) { new_params.push_back(var); new_buffer_map.Set(var, func->buffer_map[var]); } else { param_updated = true; } } } if (param_updated) { return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type, new_buffer_map, func->attrs, func->span); } else { return func; } } private: explicit StmtSimplifier( Analyzer *analyzer, SimplifyConfig config, std::optional touch_pattern, std::unordered_set used_in_buffer_def) : IRMutatorWithAnalyzer(analyzer), config_(std::move(config)), touch_pattern_(std::move(touch_pattern)), used_in_buffer_def_(std::move(used_in_buffer_def)) {} using Parent = IRMutatorWithAnalyzer; using Parent::VisitExpr_; using Parent::VisitStmt; using Parent::VisitStmt_; PrimExpr VisitExpr(const PrimExpr &expr) final { if (config_->propagate_knowns_to_simplify_expressions) { return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(), analyzer_); } else { return analyzer_->Simplify(expr); } } Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); } Stmt VisitStmt(const Stmt &stmt) override { Optional cache = this->current_stmt_; this->current_stmt_ = stmt; Stmt output = Parent::VisitStmt(stmt); this->current_stmt_ = std::move(cache); return output; } Stmt VisitStmt_(const ForNode *op) final { analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); With ctx1(analyzer_, op->loop_var >= op->min); With ctx2(analyzer_, op->loop_var < op->min + op->extent); return Parent::VisitStmt_(op); } bool CanInlineLetStmt(const LetStmtNode *op) { if (is_const_number(op->value)) return true; if (op->value.as()) return true; // Won't face the deep expression explosion problem as in Let expression. // attempt to inline as much as possible if the value integer type(can be // index). if (!op->value.dtype().is_int()) return false; return SideEffect(op->value) <= CallEffectKind::kPure; } Stmt VisitStmt_(const LetStmtNode *op) override { PrimExpr value = this->VisitExpr(op->value); bool remove_buffer_alias = false; // TileLang emits aliases like `X_shared = buffer[0:128, 0:32]` to annotate // fragment types. TVM currently reinterprets vectorized/shared accesses as // Let-bound BufferLoad/BufferRegion nodes. If these bindings survive, later // passes (Layout rewrite, FlattenBuffer) substitute them with vector lanes // that our layout can't handle. Force-inline (by dropping the let) whenever // the alias spans more than 2 dims or carries vector lanes. auto get_ranges = [&](const PrimExpr &expr) -> Array { Array ranges; if (const auto *load = expr.as()) { for (const PrimExpr &index : load->indices) { if (const auto *ramp = index.as()) { ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); } else { ranges.push_back(Range::FromMinExtent(index, Integer(1))); } } } else if (const auto *region = expr.as()) { for (const Range &range : region->region) { ranges.push_back(range); } } return ranges; }; Array ranges = get_ranges(value); if (!ranges.empty()) { int non_unit_dims = 0; for (const Range &range : ranges) { PrimExpr extent = analyzer_->Simplify(range->extent); if (is_const_int(extent, 1) || analyzer_->CanProveEqual(extent, 1)) { continue; } ++non_unit_dims; if (non_unit_dims > 1) { remove_buffer_alias = true; break; } } } if (remove_buffer_alias) { Stmt body = this->VisitStmt(op->body); bool used = UsesVar( body, [&](const VarNode *var) { return var == op->var.get(); }); ICHECK(!used) << "Let binding of BufferLoad is expected to be unused " "before removal " << op->var << " : " << op->value << " ."; return body; } bool can_inline = CanInlineLetStmt(op); if (can_inline) { analyzer_->Bind(op->var, value); } else if (SideEffect(op->value) <= CallEffectKind::kPure) { non_inlined_bindings_.Set(op->var, value); } Stmt body = this->VisitStmt(op->body); bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get()); if (can_inline && !used_in_buffer_def) { return body; } else if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); n->value = std::move(value); n->body = std::move(body); return Stmt(n); } } Stmt VisitStmt_(const IfThenElseNode *op) override { if (Optional cond = ProveCondition(op->condition)) { if (cond.value()->value) { return this->VisitStmt(op->then_case); } else if (op->else_case) { return this->VisitStmt(op->else_case.value()); } else { return Evaluate(0); } } else { return Parent::VisitStmt_(op); } } PrimExpr VisitExpr_(const CallNode *op) override { if (op->op.same_as(builtin::if_then_else())) { if (Optional cond = ProveCondition(op->args[0])) { if (cond.value()->value) { return this->VisitExpr(op->args[1]); } else { return this->VisitExpr(op->args[2]); } } } return Parent::VisitExpr_(op); } PrimExpr VisitExpr_(const VarNode *op) override { used_vars_.insert(op); return Parent::VisitExpr_(op); } PrimExpr VisitExpr_(const BufferLoadNode *op) override { auto buffer = op->buffer.get(); if (used_buffers_.find(buffer) == used_buffers_.end()) { used_buffers_.insert(buffer); } return Parent::VisitExpr_(op); } // eliminate useless stores Stmt VisitStmt_(const BufferStoreNode *op) override { BufferStore store = Downcast(Parent::VisitStmt_(op)); if (const BufferLoadNode *load = store->value.as()) { if (load->buffer->data.same_as(store->buffer->data) && ArrayDeepEqual(load->indices, store->indices) && tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) && ArrayDeepEqual(load->buffer->shape, store->buffer->shape) && ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) { return Evaluate(0); } } auto buffer = op->buffer.get(); if (used_buffers_.find(buffer) == used_buffers_.end()) { used_buffers_.insert(buffer); } return std::move(store); } private: bool ArrayDeepEqual(const Array &lhs, const Array &rhs) { if (lhs.size() != rhs.size()) { return false; } for (size_t i = 0; i < lhs.size(); i++) { if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) { return false; } } return true; } /* \brief Internal utility for checking conditionals * * Uses more aggressive optimization, such as performing additional * inlining and tracking known buffer values. */ Optional ProveCondition(PrimExpr condition) const { condition = Substitute(condition, non_inlined_bindings_); if (config_->propagate_knowns_to_prove_conditional) { ICHECK(touch_pattern_.has_value()); condition = touch_pattern_->SimplifyInContext( condition, current_stmt_.value(), analyzer_); } else { condition = analyzer_->Simplify(condition); } if (const int64_t *as_int = as_const_int(condition)) { return Bool(*as_int); } else { // May have symbolic, need kSymbolicBound level prover. if (analyzer_->CanProve(condition) || analyzer_->CanProve(condition, arith::ProofStrength::kSymbolicBound)) { return Bool(true); } return std::nullopt; } } SimplifyConfig config_; std::optional touch_pattern_; Map non_inlined_bindings_; Optional current_stmt_{std::nullopt}; std::unordered_set used_in_buffer_def_; std::unordered_set used_vars_; std::unordered_set used_buffers_; }; using namespace tir::transform; tvm::transform::Pass Simplify(bool simplify_arguments = true) { auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) { arith::Analyzer analyzer; auto cfg = ctx->GetConfig("tl.Simplify"); return StmtSimplifier::Apply(std::move(f), &analyzer, cfg, simplify_arguments); }; return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {}); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.Simplify", Simplify); }); } // namespace tl } // namespace tvm