Unverified Commit 470eb74c authored by LJC00118's avatar LJC00118 Committed by GitHub
Browse files

Improve memory access safety and `T.assume` handling (#1292)



* Improve memory access safety and T.assume handling

* Improve memory access safety and T.assume handling

* bugfix

* lint fix

* bugfix

* bugfix

* refactor legalize safe memory access pass

---------
Co-authored-by: default avatarLei Wang <leiwang1999@outlook.com>
parent 0d101c11
...@@ -24,32 +24,6 @@ namespace tl { ...@@ -24,32 +24,6 @@ namespace tl {
using namespace tir; using namespace tir;
using arith::IRMutatorWithAnalyzer; using arith::IRMutatorWithAnalyzer;
// Helper class to find leaf For nodes in a given IR
class LeafForFinder : public StmtVisitor {
public:
std::vector<For> leaf_for_nodes;
private:
void VisitStmt_(const ForNode *op) final {
has_child_for_ = false;
bool parent_has_child_for = parent_has_child_for_;
parent_has_child_for_ = false;
StmtVisitor::VisitStmt(op->body);
if (!has_child_for_) {
leaf_for_nodes.push_back(tvm::ffi::GetRef<For>(op));
}
parent_has_child_for_ = parent_has_child_for;
parent_has_child_for_ = true;
}
private:
bool has_child_for_ = false;
bool parent_has_child_for_ = false;
};
// GlobalMemChecker for a BufferLoad/BufferStore node: // GlobalMemChecker for a BufferLoad/BufferStore node:
// 1. Identify BufferLoad and BufferStore nodes. // 1. Identify BufferLoad and BufferStore nodes.
// 2. Check if the buffer is in global scope. // 2. Check if the buffer is in global scope.
...@@ -109,13 +83,16 @@ struct GlobalMemChecker : public StmtExprVisitor { ...@@ -109,13 +83,16 @@ struct GlobalMemChecker : public StmtExprVisitor {
PrimExpr index = indices[i]; PrimExpr index = indices[i];
PrimExpr shape_dim = buffer->shape[i]; PrimExpr shape_dim = buffer->shape[i];
bool has_variable = false; bool is_index_constant = true;
PostOrderVisit(index, [&](const ObjectRef &obj) { PostOrderVisit(index, [&](const ObjectRef &obj) {
if (const VarNode *v = obj.as<VarNode>()) { if (const VarNode *v = obj.as<VarNode>()) {
has_variable = true; is_index_constant = false;
}
if (const BufferLoadNode *v = obj.as<BufferLoadNode>()) {
is_index_constant = false;
} }
}); });
if (!has_variable) { if (is_index_constant) {
// If index is a constant, we can skip the check // If index is a constant, we can skip the check
continue; continue;
} }
...@@ -145,18 +122,31 @@ private: ...@@ -145,18 +122,31 @@ private:
bool recursively_collect_conds_; bool recursively_collect_conds_;
}; };
class SafeMemorysRewriter : public StmtExprMutator { class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
arith::Analyzer *analyzer_;
public: public:
explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_safe_value_map, // Static method to substitute and transform the given PrimFunc
arith::Analyzer *analyzer) static PrimFunc Substitute(PrimFunc f) {
: annotated_safe_value_map_(std::move(annotated_safe_value_map)), arith::Analyzer analyzer;
analyzer_(analyzer) {} // Create an instance of the legalizer with the analyzer
SafeMemorysRewriter substituter(&analyzer);
// Get a mutable copy of the function node
PrimFuncNode *fptr = f.CopyOnWrite();
for (const auto &[_, buffer] : f->buffer_map) {
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
}
// Apply the legalizer to the function body
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private: private:
// Constructor initializing the base class with the analyzer
SafeMemorysRewriter(arith::Analyzer *analyzer)
: arith::IRMutatorWithAnalyzer(analyzer) {}
// Constructor initializing the base class with the analyzer
PrimExpr VisitExpr_(const BufferLoadNode *op) final { PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
// For Load/Store, we only check the current node, not its children. // For Load/Store, we only check the current node, not its children.
// Since rewriter will recursively visit children. // Since rewriter will recursively visit children.
...@@ -181,7 +171,7 @@ private: ...@@ -181,7 +171,7 @@ private:
Stmt VisitStmt_(const BufferStoreNode *op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is in global scope // Check if the buffer is in global scope
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
checker(store); checker(store);
...@@ -253,6 +243,25 @@ private: ...@@ -253,6 +243,25 @@ private:
return evaluate; return evaluate;
} }
Stmt VisitStmt_(const BlockNode *op) final {
for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
if (op->annotations.count(attr::kSafeValueMap)) {
auto map = op->annotations.Get(attr::kSafeValueMap)
->as<Map<Var, PrimExpr>>()
.value();
for (const auto &[var, safe_value] : map) {
ICHECK(buffer_data_to_buffer_.count(var))
<< "buffer " << var << " is not found in the block "
<< buffer_data_to_buffer_;
auto buffer = buffer_data_to_buffer_[var];
annotated_safe_value_map_.Set(buffer, safe_value);
}
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
bool IsLocalBuffer(const Buffer &buffer) { bool IsLocalBuffer(const Buffer &buffer) {
String scope = buffer.scope(); String scope = buffer.scope();
return scope == "local" || scope == "local.fragment" || return scope == "local" || scope == "local.fragment" ||
...@@ -276,87 +285,6 @@ private: ...@@ -276,87 +285,6 @@ private:
return make_zero(buffer->dtype); return make_zero(buffer->dtype);
} }
Map<Buffer, PrimExpr> annotated_safe_value_map_;
};
// Class to legalize safe memory access by transforming them appropriately
class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
public:
// Static method to substitute and transform the given PrimFunc
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
// Create an instance of the legalizer with the analyzer
SafeMemoryLegalizer substituter(&analyzer);
// Get a mutable copy of the function node
PrimFuncNode *fptr = f.CopyOnWrite();
for (const auto &[_, buffer] : f->buffer_map) {
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
}
// Apply the legalizer to the function body
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
// Constructor initializing the base class with the analyzer
SafeMemoryLegalizer(arith::Analyzer *analyzer)
: arith::IRMutatorWithAnalyzer(analyzer) {}
// Override the VisitStmt_ method to handle ForNode (loop statements)
Stmt VisitStmt_(const ForNode *op) final {
// Visit and potentially modify the loop node
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto has_inner_loop = HasInnerLoop(for_node->body);
if (!has_inner_loop) {
SafeMemorysRewriter rewriter(annotated_safe_value_map_, analyzer_);
for_node.CopyOnWrite()->body = rewriter(for_node->body);
// // Detect Buffer Load Node in the loop body, collect the indices and
// buffer size
// // Run the checker on the loop body
// GlobalMemChecker checker(analyzer_);
// checker(for_node->body);
// Array<PrimExpr> conditions = checker.GetConditions();
// auto body = for_node->body;
// // Note that we might have duplicate conditions
// // Which will be optimized by simplify pass
// // Replace the loop body with the new body
// for (auto cond : conditions) {
// body = IfThenElse(cond, body);
// }
// for_node.CopyOnWrite()->body = body;
return std::move(for_node);
}
// Visit a For Node
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
Stmt VisitStmt_(const BlockNode *op) final {
for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
if (op->annotations.count(attr::kSafeValueMap)) {
auto map = op->annotations.Get(attr::kSafeValueMap)
->as<Map<Var, PrimExpr>>()
.value();
for (const auto &[var, safe_value] : map) {
ICHECK(buffer_data_to_buffer_.count(var))
<< "buffer " << var << " is not found in the block "
<< buffer_data_to_buffer_;
auto buffer = buffer_data_to_buffer_[var];
annotated_safe_value_map_.Set(buffer, safe_value);
}
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
static bool HasInnerLoop(const Stmt &stmt) {
LeafForFinder finder;
finder(stmt);
return !finder.leaf_for_nodes.empty();
}
Map<Var, Buffer> buffer_data_to_buffer_; Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, PrimExpr> annotated_safe_value_map_; Map<Buffer, PrimExpr> annotated_safe_value_map_;
}; };
...@@ -371,7 +299,7 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() { ...@@ -371,7 +299,7 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
if (disable_safe_memory_legalize) { if (disable_safe_memory_legalize) {
return f; return f;
} }
return SafeMemoryLegalizer::Substitute(std::move(f)); return SafeMemorysRewriter::Substitute(std::move(f));
}; };
// Create and return a PrimFunc pass with the transformation function // Create and return a PrimFunc pass with the transformation function
return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeSafeMemoryAccess", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeSafeMemoryAccess", {});
......
...@@ -465,6 +465,16 @@ private: ...@@ -465,6 +465,16 @@ private:
return std::move(store); return std::move(store);
} }
Stmt VisitStmt_(const AttrStmtNode *op) override {
if (op->attr_key == "tl.assume") {
PrimExpr condition = this->VisitExpr(Downcast<PrimExpr>(op->node));
auto n = CopyOnWrite(op);
n->node = std::move(condition);
return Parent::VisitStmt_(n.get());
}
return Parent::VisitStmt_(op);
}
private: private:
bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) { bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
if (lhs.size() != rhs.size()) { if (lhs.size() != rhs.size()) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment