/*! * \file legalize_safe_memory_access.cc * \brief legalize safe memory access */ #include #include #include #include #include #include #include #include "../op/builtin.h" #include "../op/parallel.h" #include "arith/ir_mutator_with_analyzer.h" #include "loop_partition.h" #include "loop_vectorize.h" namespace tvm { namespace tl { using namespace tir; using arith::IRMutatorWithAnalyzer; // Helper class to find leaf For nodes in a given IR class LeafForFinder : public StmtVisitor { public: std::vector leaf_for_nodes; private: void VisitStmt_(const ForNode *op) final { has_child_for_ = false; bool parent_has_child_for = parent_has_child_for_; parent_has_child_for_ = false; StmtVisitor::VisitStmt(op->body); if (!has_child_for_) { leaf_for_nodes.push_back(GetRef(op)); } parent_has_child_for_ = parent_has_child_for; parent_has_child_for_ = true; } private: bool has_child_for_ = false; bool parent_has_child_for_ = false; }; // GlobalMemChecker for a BufferLoad/BufferStore node: // 1. Identify BufferLoad and BufferStore nodes. // 2. Check if the buffer is in global scope. // 3. For each index, compare against the buffer's shape. // If the index might exceed the shape (upper bound too large), // log a warning or handle accordingly. struct GlobalMemChecker : public StmtExprVisitor { GlobalMemChecker(arith::Analyzer *analyzer, bool recursively_collect_conds) : analyzer_(analyzer), recursively_collect_conds_(recursively_collect_conds) {} void VisitExpr_(const BufferLoadNode *op) final { // Check if the buffer is in global scope // This is because we are writing TilePrograms, where out of bounds // accesses only happen in the global buffer. if (IsGlobalBuffer(op->buffer)) { CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true); } if (recursively_collect_conds_) { StmtExprVisitor::VisitExpr_(op); } } void VisitStmt_(const BufferStoreNode *op) final { // Check if the buffer is in global scope if (IsGlobalBuffer(op->buffer)) { CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false); } if (recursively_collect_conds_) { StmtExprVisitor::VisitStmt_(op); } } // Helper function to determine if a buffer is global bool IsGlobalBuffer(const Buffer &buffer) { // The storage scope is often encoded in the buffer->data var name or // associated attributes. In typical TVM IR, global buffers have scope // "global". Here we assume a helper function GetPtrStorageScope is // available. If not, you might need to parse buffer->data->name_hint or // associated attributes. String scope = buffer.scope(); return scope == "global"; } // Check each index against the buffer shape dimensions void CheckBufferIndices(const Buffer &buffer, const Array &indices, bool is_load) { // Ensure indices count matches buffer dimension if (indices.size() != buffer->shape.size()) { LOG(WARNING) << "Buffer access dimension mismatch: indices size (" << indices.size() << ") vs. shape size (" << buffer->shape.size() << ")"; return; } for (size_t i = 0; i < indices.size(); i++) { PrimExpr index = indices[i]; PrimExpr shape_dim = buffer->shape[i]; bool has_variable = false; PostOrderVisit(index, [&](const ObjectRef &obj) { if (const VarNode *v = obj.as()) { has_variable = true; } }); if (!has_variable) { // If index is a constant, we can skip the check continue; } // We want to check if index < shape_dim can be proven. // If analyzer->CanProve(index < shape_dim) returns false, // it means we cannot prove the access is within bounds. PrimExpr upper_bound_cond = index < shape_dim; if (!analyzer_->CanProve(upper_bound_cond, arith::ProofStrength::kSymbolicBound)) { _conditions.push_back(upper_bound_cond); } // Check if index >= 0 can be proven. PrimExpr lower_bound_cond = index >= 0; if (!analyzer_->CanProve(lower_bound_cond, arith::ProofStrength::kSymbolicBound)) { _conditions.push_back(lower_bound_cond); } } } Array GetConditions() { return _conditions; } private: Array _conditions; arith::Analyzer *analyzer_; bool recursively_collect_conds_; }; class SafeMemorysRewriter : public StmtExprMutator { arith::Analyzer *analyzer_; public: explicit SafeMemorysRewriter(Map annotated_safe_value_map, arith::Analyzer *analyzer) : annotated_safe_value_map_(std::move(annotated_safe_value_map)), analyzer_(analyzer) {} private: PrimExpr VisitExpr_(const BufferLoadNode *op) final { auto load = Downcast(StmtExprMutator::VisitExpr_(op)); // For Load/Store, we only check the current node, not its children. // Since rewriter will recursively visit children. GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); checker(load); Array conditions = checker.GetConditions(); if (conditions.empty()) { return load; } // For loading, we can always use safe value if the access is out of // bounds PrimExpr value = load; for (auto cond : conditions) { ICHECK(cond.dtype() == DataType::Bool(1)) << "condition is not a boolean: " << cond; value = if_then_else(cond, value, GetSafeValue(load->buffer)); } return value; } Stmt VisitStmt_(const BufferStoreNode *op) final { // Check if the buffer is in global scope auto store = Downcast(StmtExprMutator::VisitStmt_(op)); GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); checker(store); Array conditions = checker.GetConditions(); // Skip boundary check if the store value is an IfThenElse if (const IfThenElseNode *if_node = store->value.as()) { if (!conditions.empty()) { LOG(WARNING) << "Skipping boundary check for store with IfThenElse value: " << store->value << "\nAs manual boundary check detected, potential out-of-bounds " "access may occur." << "\nAuto detect boundaries are " << conditions; return store; } return store; } if (conditions.empty()) { return store; } // If a store is out of bounds, we skip the corresponding stmt directly. Stmt store_with_conditions = store; for (auto cond : conditions) { store_with_conditions = IfThenElse(cond, store_with_conditions); } return store_with_conditions; } // Recursively check Load/Store in the call arguments. // For example // T.call_extern("handle", "atomicAddx2", T.address_of(C), // T.address_of(C_shared)) // NOTE(chaofan): This is currently not the most rigorous solution. // The check here is primarily intended to handle extern functions like // atomicAdd, which may involve memory access. Due to their special nature, // the BufferLoad in their parameters might be used for boundary checks of the // current statement. The current solution adopts a simplified approach: // directly applying the boundary constraints of all parameters to the // statement. While not entirely precise, it addresses most common scenarios. Stmt VisitStmt_(const EvaluateNode *op) final { auto evaluate = Downcast(op); if (const CallNode *call_op = op->value.as()) { auto call = Downcast(op->value); if (call->op == builtin::call_extern()) { // For CallExtern, we recursively collect conditions from all children. // Since we cannot rewrite any BufferLoad in its children (Rewrite will // cause potential Nullptr exception). GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/true); checker(call); Array conditions = checker.GetConditions(); if (conditions.empty()) { return evaluate; } Stmt evaluate_with_conditions = evaluate; for (auto cond : conditions) { evaluate_with_conditions = IfThenElse(cond, evaluate_with_conditions); } return evaluate_with_conditions; } } return evaluate; } bool IsLocalBuffer(const Buffer &buffer) { String scope = buffer.scope(); return scope == "local" || scope == "local.fragment" || scope == "local.var"; } bool isSharedBuffer(const Buffer &buffer) { String scope = buffer.scope(); return scope == "shared" || scope == "shared.dyn"; } bool IsGlobalBuffer(const Buffer &buffer) { String scope = buffer.scope(); return scope == "global"; } // Get the safe value of the buffer PrimExpr GetSafeValue(const Buffer &buffer) { if (annotated_safe_value_map_.count(buffer)) { return annotated_safe_value_map_[buffer]; } return make_zero(buffer->dtype); } Map annotated_safe_value_map_; }; // Class to legalize safe memory access by transforming them appropriately class SafeMemoryLegalizer : IRMutatorWithAnalyzer { public: // Static method to substitute and transform the given PrimFunc static PrimFunc Substitute(PrimFunc f) { arith::Analyzer analyzer; // Create an instance of the legalizer with the analyzer SafeMemoryLegalizer substituter(&analyzer); // Get a mutable copy of the function node PrimFuncNode *fptr = f.CopyOnWrite(); for (const auto &[_, buffer] : f->buffer_map) { substituter.buffer_data_to_buffer_.Set(buffer->data, buffer); } // Apply the legalizer to the function body fptr->body = substituter.VisitStmt(f->body); return f; } private: // Constructor initializing the base class with the analyzer SafeMemoryLegalizer(arith::Analyzer *analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {} // Override the VisitStmt_ method to handle ForNode (loop statements) Stmt VisitStmt_(const ForNode *op) final { // Visit and potentially modify the loop node For for_node = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); auto has_inner_loop = HasInnerLoop(for_node->body); if (!has_inner_loop) { SafeMemorysRewriter rewriter(annotated_safe_value_map_, analyzer_); for_node.CopyOnWrite()->body = rewriter(for_node->body); // // Detect Buffer Load Node in the loop body, collect the indices and // buffer size // // Run the checker on the loop body // GlobalMemChecker checker(analyzer_); // checker(for_node->body); // Array conditions = checker.GetConditions(); // auto body = for_node->body; // // Note that we might have duplicate conditions // // Which will be optimized by simplify pass // // Replace the loop body with the new body // for (auto cond : conditions) { // body = IfThenElse(cond, body); // } // for_node.CopyOnWrite()->body = body; return std::move(for_node); } // Visit a For Node return IRMutatorWithAnalyzer::VisitStmt_(op); } Stmt VisitStmt_(const BlockNode *op) final { for (auto buffer : op->alloc_buffers) { buffer_data_to_buffer_.Set(buffer->data, buffer); } if (op->annotations.count(attr::kSafeValueMap)) { auto map = op->annotations.Get(attr::kSafeValueMap) ->as>() .value(); for (const auto &[var, safe_value] : map) { ICHECK(buffer_data_to_buffer_.count(var)) << "buffer " << var << " is not found in the block " << buffer_data_to_buffer_; auto buffer = buffer_data_to_buffer_[var]; annotated_safe_value_map_.Set(buffer, safe_value); } } return IRMutatorWithAnalyzer::VisitStmt_(op); } static bool HasInnerLoop(const Stmt &stmt) { LeafForFinder finder; finder(stmt); return !finder.leaf_for_nodes.empty(); } Map buffer_data_to_buffer_; Map annotated_safe_value_map_; }; // Create a pass that legalizes vectorized loops in the IRModule tvm::transform::Pass LegalizeSafeMemoryAccess() { using namespace tir::transform; // Define the transformation function to be applied auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) { bool disable_safe_memory_legalize = ctx->GetConfig(kDisableSafeMemoryLegalize, Bool(false)).value(); if (disable_safe_memory_legalize) { return f; } return SafeMemoryLegalizer::Substitute(std::move(f)); }; // Create and return a PrimFunc pass with the transformation function return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeSafeMemoryAccess", {}); } // Register the pass globally so it can be used in the compilation pipeline TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess", LegalizeSafeMemoryAccess); }); } // namespace tl } // namespace tvm