Commit 689ee52b (unverified), authored by Yu Cheng, committed by GitHub
Browse files

[Enhancement] Refactored buffer detection logic in warp_specialized_rewriter.cc (#685)

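- Renamed TMAFinder to ProducerBufferDetector and improved its handling of CallNode and BufferLoadNode: it now recurses into call arguments and also flags any BufferLoad from a buffer already known to be a producer buffer.
- ProducerUsedBufferFinder now re-visits the statement until the set of producer buffers stops changing, and a BufferStore into a producer buffer marks the buffers read by the stored value as producer buffers as well.
- The set_max_nreg register hints are no longer emitted when the role marker reports a SIMT copy.
- This change aims to enhance code maintainability and performance by more accurately tracking producer buffer usage.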
parent adcba275
...@@ -23,24 +23,45 @@ using arith::IRVisitorWithAnalyzer; ...@@ -23,24 +23,45 @@ using arith::IRVisitorWithAnalyzer;
enum class Role { kConsumer, kProducer, kBoth }; enum class Role { kConsumer, kProducer, kBoth };
class TMAFinder : public StmtExprVisitor { class ProducerBufferDetector : public StmtExprVisitor {
public: public:
void clear() { has_tma_load_ = false; } ProducerBufferDetector(
std::unordered_set<const BufferNode *> cur_producer_buffers)
: cur_producer_buffers_(cur_producer_buffers) {}
void clear() { has_producer_buffer_ = false; }
void VisitExpr_(const CallNode *call) final { void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
has_tma_load_ = true; has_producer_buffer_ = true;
} }
StmtExprVisitor::VisitExpr_(call);
} }
bool has_tma_load_ = false; void VisitExpr_(const BufferLoadNode *op) final {
if (cur_producer_buffers_.count(op->buffer.get())) {
has_producer_buffer_ = true;
}
StmtExprVisitor::VisitExpr_(op);
}
bool has_producer_buffer_ = false;
std::unordered_set<const BufferNode *> cur_producer_buffers_;
}; };
class ProducerUsedBufferFinder : public StmtExprVisitor { class ProducerUsedBufferFinder : public StmtExprVisitor {
public: public:
auto FindProducerusedBuffer(Stmt stmt) { auto FindProducerusedBuffer(Stmt stmt) {
producer_buffers_.clear();
std::unordered_set<const BufferNode *> last_producer_buffers_;
for (;;) {
VisitStmt(stmt); VisitStmt(stmt);
return used_in_producer_cond_; if (producer_buffers_ == last_producer_buffers_) {
break;
}
last_producer_buffers_ = producer_buffers_;
}
return producer_buffers_;
} }
void InsertBuffer(const PrimExpr &expr) { void InsertBuffer(const PrimExpr &expr) {
...@@ -48,44 +69,51 @@ public: ...@@ -48,44 +69,51 @@ public:
VarUseDefAnalyzer usage(Array<Var>{}); VarUseDefAnalyzer usage(Array<Var>{});
usage(expr); usage(expr);
for (const auto &buffer : usage.buffer_use_count_) { for (const auto &buffer : usage.buffer_use_count_) {
used_in_producer_cond_.insert(buffer.first); producer_buffers_.insert(buffer.first);
} }
} }
void VisitStmt_(const IfThenElseNode *op) final { void VisitStmt_(const IfThenElseNode *op) final {
TMAFinder tma_finder; ProducerBufferDetector producer_buffer_detector(producer_buffers_);
tma_finder(op->then_case); producer_buffer_detector(op->then_case);
if (op->else_case.defined()) { if (op->else_case.defined()) {
tma_finder(op->else_case.value()); producer_buffer_detector(op->else_case.value());
} }
if (tma_finder.has_tma_load_) { if (producer_buffer_detector.has_producer_buffer_) {
InsertBuffer(op->condition); InsertBuffer(op->condition);
} }
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
} }
void VisitStmt_(const ForNode *op) final { void VisitStmt_(const ForNode *op) final {
TMAFinder tma_finder; ProducerBufferDetector producer_buffer_detector(producer_buffers_);
tma_finder(op->body); producer_buffer_detector(op->body);
if (tma_finder.has_tma_load_) { if (producer_buffer_detector.has_producer_buffer_) {
InsertBuffer(op->min); InsertBuffer(op->min);
InsertBuffer(op->extent); InsertBuffer(op->extent);
} }
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
} }
void VisitStmt_(const BufferStoreNode *op) final {
if (producer_buffers_.count(op->buffer.get())) {
InsertBuffer(op->value);
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const CallNode *op) final { void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
for (auto arg : op->args) { for (auto arg : op->args) {
if (auto buffer_load = arg.as<BufferLoadNode>()) { if (auto buffer_load = arg.as<BufferLoadNode>()) {
used_in_producer_cond_.insert(buffer_load->buffer.get()); producer_buffers_.insert(buffer_load->buffer.get());
} }
} }
} }
} }
private: private:
std::unordered_set<const BufferNode *> used_in_producer_cond_; std::unordered_set<const BufferNode *> producer_buffers_;
}; };
class WarpSpecializedRoleMarker : public StmtVisitor { class WarpSpecializedRoleMarker : public StmtVisitor {
...@@ -95,7 +123,7 @@ public: ...@@ -95,7 +123,7 @@ public:
void Prepare(const Stmt &stmt) { void Prepare(const Stmt &stmt) {
ProducerUsedBufferFinder finder; ProducerUsedBufferFinder finder;
used_in_producer_cond_ = finder.FindProducerusedBuffer(stmt); producer_buffers_ = finder.FindProducerusedBuffer(stmt);
} }
Role GetRole(const StmtNode *stmt) const { Role GetRole(const StmtNode *stmt) const {
...@@ -123,7 +151,7 @@ public: ...@@ -123,7 +151,7 @@ public:
void VisitStmt_(const BufferStoreNode *op) final { void VisitStmt_(const BufferStoreNode *op) final {
bool is_shared_store = bool is_shared_store =
op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
if (used_in_producer_cond_.count(op->buffer.get())) { if (producer_buffers_.count(op->buffer.get())) {
SetRole(op, Role::kBoth); SetRole(op, Role::kBoth);
return; return;
} }
...@@ -207,7 +235,7 @@ private: ...@@ -207,7 +235,7 @@ private:
std::unordered_map<const StmtNode *, Role> map_; std::unordered_map<const StmtNode *, Role> map_;
bool has_simt_copy_ = false; bool has_simt_copy_ = false;
bool has_bulk_copy_ = false; bool has_bulk_copy_ = false;
std::unordered_set<const BufferNode *> used_in_producer_cond_; std::unordered_set<const BufferNode *> producer_buffers_;
}; };
static PrimExpr makeGetBarrier(PrimExpr barrier_id) { static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
...@@ -1112,7 +1140,7 @@ private: ...@@ -1112,7 +1140,7 @@ private:
auto inc_reg_stmt = Evaluate(0); auto inc_reg_stmt = Evaluate(0);
auto dec_reg_stmt = Evaluate(0); auto dec_reg_stmt = Evaluate(0);
if (dec_reg >= 0 && inc_reg >= 0) { if (dec_reg >= 0 && inc_reg >= 0 && !marker.HasSimtCopy()) {
inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(), inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
{inc_reg == 0 ? 240 : inc_reg, 1})); {inc_reg == 0 ? 240 : inc_reg, 1}));
dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(), dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment