Commit 84ddb9e1 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Update GlobalMemChecker to Detect Lower Bound illegal memory access automatically (#505)

* [Refactor] Update GlobalMemChecker to use IRVisitorWithAnalyzer for improved analysis (#505)

* Refactored GlobalMemChecker to inherit from IRVisitorWithAnalyzer, enhancing its capabilities for expression analysis.
* Updated condition checks to utilize the new analyzer interface, improving clarity and correctness in memory access validation.
* Added additional lower bound condition checks to ensure comprehensive validation of memory access indices.

* [Refactor] Update GlobalMemChecker to use StmtExprVisitor for improved memory access validation

* Refactored GlobalMemChecker to inherit from StmtExprVisitor, enhancing its capabilities for expression analysis.
* Updated condition checks to utilize the new analyzer interface, improving clarity and correctness in memory access validation.
* Ensured that the analyzer is passed correctly during instantiation, maintaining consistency in condition checks.
parent c59e1aab
...@@ -9,8 +9,6 @@ ...@@ -9,8 +9,6 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <tvm/tir/utils.h> #include <tvm/tir/utils.h>
#include <queue>
#include "../op/builtin.h" #include "../op/builtin.h"
#include "../op/parallel.h" #include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_mutator_with_analyzer.h"
...@@ -57,10 +55,8 @@ private: ...@@ -57,10 +55,8 @@ private:
// If the index might exceed the shape (upper bound too large), // If the index might exceed the shape (upper bound too large),
// log a warning or handle accordingly. // log a warning or handle accordingly.
struct GlobalMemChecker : public StmtExprVisitor { struct GlobalMemChecker : public StmtExprVisitor {
arith::Analyzer *analyzer;
explicit GlobalMemChecker(arith::Analyzer *analyzer) : analyzer(analyzer) {}
GlobalMemChecker(arith::Analyzer *analyzer) : analyzer_(analyzer) {}
void VisitExpr_(const BufferLoadNode *op) final { void VisitExpr_(const BufferLoadNode *op) final {
// Check if the buffer is in global scope // Check if the buffer is in global scope
if (IsGlobalBuffer(op->buffer)) { if (IsGlobalBuffer(op->buffer)) {
...@@ -116,9 +112,14 @@ struct GlobalMemChecker : public StmtExprVisitor { ...@@ -116,9 +112,14 @@ struct GlobalMemChecker : public StmtExprVisitor {
// We want to check if index < shape_dim can be proven. // We want to check if index < shape_dim can be proven.
// If analyzer->CanProve(index < shape_dim) returns false, // If analyzer->CanProve(index < shape_dim) returns false,
// it means we cannot prove the access is within bounds. // it means we cannot prove the access is within bounds.
PrimExpr cond = index < shape_dim; PrimExpr upper_bound_cond = index < shape_dim;
if (!analyzer->CanProve(cond)) { if (!analyzer_->CanProve(upper_bound_cond)) {
_conditions.push_back(cond); _conditions.push_back(upper_bound_cond);
}
// Check if index >= 0 can be proven.
PrimExpr lower_bound_cond = index >= 0;
if (!analyzer_->CanProve(lower_bound_cond)) {
_conditions.push_back(lower_bound_cond);
} }
} }
} }
...@@ -127,6 +128,7 @@ struct GlobalMemChecker : public StmtExprVisitor { ...@@ -127,6 +128,7 @@ struct GlobalMemChecker : public StmtExprVisitor {
private: private:
Array<PrimExpr> _conditions; Array<PrimExpr> _conditions;
arith::Analyzer *analyzer_;
}; };
class SafeMemorysRewriter : public StmtExprMutator { class SafeMemorysRewriter : public StmtExprMutator {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment