Unverified Commit c1eef511 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Pipeline] Skip condition expression analysis for global reading (#713)

* Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107

* Refactor inject_pipeline.cc to enhance pipeline body rewriting and condition handling

- Introduced a new function to replace IfThenElse nodes with their then_case while preserving attributes.
- Streamlined the PipelineBodyRewriter to improve buffer access rewriting and async state management.
- Enhanced the handling of pipeline loop conditions and added support for predicate conditions in the pipeline body.
- Removed obsolete code and improved overall code clarity and maintainability.

* lint fix

* Refactor return statements in inject_pipeline.cc to remove unnecessary std::move calls

- Updated return statements in multiple methods to return objects directly instead of using std::move, improving code clarity and potentially avoiding unnecessary moves.
- Ensured consistent handling of BufferStore and BufferLoad nodes during pipeline transformations.

* test fix

* Enhance global read detection in pipeline planning

- Updated the handling of global reads to account for condition expressions within IfThenElse nodes, ensuring accurate identification of global memory accesses.
- Introduced a new flag to track whether the visitor is within a condition expression, improving the correctness of buffer access analysis.
- Refactored the VisitStmt_ method to properly handle the structure of IfThenElse nodes, enhancing the clarity and maintainability of the code.
parent 49d5d80e
......@@ -6,6 +6,7 @@
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "tvm/ir/expr.h"
namespace tvm {
namespace tl {
......@@ -81,7 +82,11 @@ private:
auto load_region = BufferRegion(load_buffer, region);
reads_.push_back(load_region);
if (op->buffer.scope() == "global") {
if (op->buffer.scope() == "global" && !within_condition_expr_) {
// skip condition expr of if_then_else node
// shared[i] = T.if_then_else(global[i] < n, register_a[i], register_b[i])
// is not a global read shared[i] = T.if_then_else(global[i] < n,
// global_a[i], global_b[i]) is a global read
is_global_read_ = true;
}
}
......@@ -103,11 +108,30 @@ private:
// because we only care about the buffer itself instead of indices
reads_.push_back(buffer_region);
}
} else if (op->op.same_as(builtin::if_then_else())) {
within_condition_expr_ = true;
this->VisitExpr(op->args[0]);
within_condition_expr_ = false;
for (auto i = 1; i < op->args.size(); i++) {
this->VisitExpr(op->args[i]);
}
} else {
StmtExprVisitor::VisitExpr_(op);
}
}
void VisitStmt_(const IfThenElseNode *op) final {
within_condition_expr_ = true;
this->VisitExpr(op->condition);
within_condition_expr_ = false;
this->VisitStmt(op->then_case);
if (op->else_case.defined()) {
within_condition_expr_ = true;
this->VisitStmt(op->else_case.value());
within_condition_expr_ = false;
}
}
private:
Map<Var, Buffer> buffer_data_to_buffer_;
Array<BufferRegion> reads_;
......@@ -115,6 +139,7 @@ private:
bool is_global_read_ = false;
bool under_buffer_store_ = false;
bool is_global_copy_pattern_ = false;
bool within_condition_expr_ = false;
};
class PipelinePlanner : public StmtExprMutator {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment