Unverified commit 689ee52b, authored by Yu Cheng, committed by GitHub
Browse files

[Enhancement] Refactored buffer detection logic in warp_specialized_rewriter.cc (#685)

- Renamed TMAFinder to ProducerBufferDetector; in addition to TMA load calls (CallNode), it now also flags loads from already-known producer buffers (BufferLoadNode).
- This change aims to enhance code maintainability and performance by more accurately tracking producer buffer usage.
parent adcba275
......@@ -23,24 +23,45 @@ using arith::IRVisitorWithAnalyzer;
enum class Role { kConsumer, kProducer, kBoth };
// NOTE(review): this span is a rendered diff in which removed and added lines
// are interleaved without +/- markers. Per the commit message, TMAFinder was
// renamed to ProducerBufferDetector, so the TMAFinder / has_tma_load_ lines
// are presumably the removed side and the ProducerBufferDetector /
// has_producer_buffer_ lines the added side.
//
// Visitor that raises a flag when the visited IR contains either
//   * a call whose op is tma_load() or tma_load_im2col(), or
//   * a BufferLoad from a buffer contained in cur_producer_buffers_.
// The flag is only ever set to true while visiting; clear() is the only
// member that resets it.
class TMAFinder : public StmtExprVisitor {
class ProducerBufferDetector : public StmtExprVisitor {
public:
void clear() { has_tma_load_ = false; }
// The set is taken by value and then copied again into the member, so every
// detector instance owns a private snapshot of the caller's set.
ProducerBufferDetector(
std::unordered_set<const BufferNode *> cur_producer_buffers)
: cur_producer_buffers_(cur_producer_buffers) {}
// Resets the detection flag; the buffer snapshot is left untouched.
void clear() { has_producer_buffer_ = false; }
// Flags TMA load calls, then continues into the call's children via the base
// visitor.
void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
has_tma_load_ = true;
has_producer_buffer_ = true;
}
StmtExprVisitor::VisitExpr_(call);
}
bool has_tma_load_ = false;
// Flags a read of any buffer in the snapshot, then continues into the load's
// children via the base visitor.
void VisitExpr_(const BufferLoadNode *op) final {
if (cur_producer_buffers_.count(op->buffer.get())) {
has_producer_buffer_ = true;
}
StmtExprVisitor::VisitExpr_(op);
}
// Result of the scan: true once a TMA load call or a producer-buffer read has
// been seen.
bool has_producer_buffer_ = false;
// Snapshot of the buffers treated as producer buffers for this scan.
std::unordered_set<const BufferNode *> cur_producer_buffers_;
};
// NOTE(review): this span is a rendered diff in which removed and added lines
// are interleaved without +/- markers. Lines using TMAFinder / tma_finder /
// used_in_producer_cond_ are presumably the removed side; lines using
// ProducerBufferDetector / producer_buffer_detector / producer_buffers_ are
// presumably the added side.
//
// Collects a set of BufferNode pointers ("producer buffers"). Seeds are the
// buffers passed directly as BufferLoad arguments to tma_load() /
// tma_load_im2col(); the set is then grown with buffers used by loop bounds,
// branch conditions and stored values that feed code touching those buffers.
class ProducerUsedBufferFinder : public StmtExprVisitor {
public:
// Visits `stmt` repeatedly until a full pass leaves producer_buffers_
// unchanged, then returns the set. In the visible code the set is only
// inserted into after the initial clear(), so each extra pass either grows
// it or ends the loop.
auto FindProducerusedBuffer(Stmt stmt) {
VisitStmt(stmt);
return used_in_producer_cond_;
producer_buffers_.clear();
std::unordered_set<const BufferNode *> last_producer_buffers_;
for (;;) {
VisitStmt(stmt);
if (producer_buffers_ == last_producer_buffers_) {
break;
}
last_producer_buffers_ = producer_buffers_;
}
return producer_buffers_;
}
// Adds every buffer that VarUseDefAnalyzer records in buffer_use_count_ for
// `expr` to the producer set.
// NOTE(review): the hunk marker below separates two diff hunks, so source
// lines may be hidden between the signature and the body.
void InsertBuffer(const PrimExpr &expr) {
......@@ -48,44 +69,51 @@ public:
VarUseDefAnalyzer usage(Array<Var>{});
usage(expr);
for (const auto &buffer : usage.buffer_use_count_) {
used_in_producer_cond_.insert(buffer.first);
producer_buffers_.insert(buffer.first);
}
}
// If either branch contains a TMA load or a read of a known producer buffer,
// the buffers used by the condition become producer buffers as well. The
// detector is constructed from a copy of the current producer_buffers_.
void VisitStmt_(const IfThenElseNode *op) final {
TMAFinder tma_finder;
tma_finder(op->then_case);
ProducerBufferDetector producer_buffer_detector(producer_buffers_);
producer_buffer_detector(op->then_case);
if (op->else_case.defined()) {
tma_finder(op->else_case.value());
producer_buffer_detector(op->else_case.value());
}
if (tma_finder.has_tma_load_) {
if (producer_buffer_detector.has_producer_buffer_) {
InsertBuffer(op->condition);
}
StmtExprVisitor::VisitStmt_(op);
}
// Same rule for loops: when the body contains a TMA load or a read of a known
// producer buffer, buffers used by the loop's min and extent are added.
void VisitStmt_(const ForNode *op) final {
TMAFinder tma_finder;
tma_finder(op->body);
if (tma_finder.has_tma_load_) {
ProducerBufferDetector producer_buffer_detector(producer_buffers_);
producer_buffer_detector(op->body);
if (producer_buffer_detector.has_producer_buffer_) {
InsertBuffer(op->min);
InsertBuffer(op->extent);
}
StmtExprVisitor::VisitStmt_(op);
}
// A store into a known producer buffer makes the buffers used by the stored
// value producer buffers too. Membership is tested against the set as it
// stands at this point of the traversal, which is presumably why
// FindProducerusedBuffer re-visits until the set stabilizes.
void VisitStmt_(const BufferStoreNode *op) final {
if (producer_buffers_.count(op->buffer.get())) {
InsertBuffer(op->value);
}
StmtExprVisitor::VisitStmt_(op);
}
// Seeds the set: for tma_load() / tma_load_im2col() calls, every argument
// that is itself a BufferLoadNode contributes its buffer.
// NOTE(review): unlike the other overrides, this one does not forward to
// StmtExprVisitor::VisitExpr_, so the arguments of any CallNode are not
// traversed further from here - confirm this is intended.
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
for (auto arg : op->args) {
if (auto buffer_load = arg.as<BufferLoadNode>()) {
used_in_producer_cond_.insert(buffer_load->buffer.get());
producer_buffers_.insert(buffer_load->buffer.get());
}
}
}
}
private:
std::unordered_set<const BufferNode *> used_in_producer_cond_;
// Accumulated result; cleared at the start of FindProducerusedBuffer.
std::unordered_set<const BufferNode *> producer_buffers_;
};
class WarpSpecializedRoleMarker : public StmtVisitor {
......@@ -95,7 +123,7 @@ public:
// Runs a fresh ProducerUsedBufferFinder over `stmt` and stores the returned
// buffer set in this object's member.
// NOTE(review): the two assignments below are presumably the removed
// (used_in_producer_cond_) and added (producer_buffers_) sides of one diff
// line, not two separate calls.
void Prepare(const Stmt &stmt) {
ProducerUsedBufferFinder finder;
used_in_producer_cond_ = finder.FindProducerusedBuffer(stmt);
producer_buffers_ = finder.FindProducerusedBuffer(stmt);
}
Role GetRole(const StmtNode *stmt) const {
......@@ -123,7 +151,7 @@ public:
void VisitStmt_(const BufferStoreNode *op) final {
bool is_shared_store =
op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
if (used_in_producer_cond_.count(op->buffer.get())) {
if (producer_buffers_.count(op->buffer.get())) {
SetRole(op, Role::kBoth);
return;
}
......@@ -207,7 +235,7 @@ private:
std::unordered_map<const StmtNode *, Role> map_;
bool has_simt_copy_ = false;
bool has_bulk_copy_ = false;
std::unordered_set<const BufferNode *> used_in_producer_cond_;
std::unordered_set<const BufferNode *> producer_buffers_;
};
static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
......@@ -1112,7 +1140,7 @@ private:
auto inc_reg_stmt = Evaluate(0);
auto dec_reg_stmt = Evaluate(0);
if (dec_reg >= 0 && inc_reg >= 0) {
if (dec_reg >= 0 && inc_reg >= 0 && !marker.HasSimtCopy()) {
inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
{inc_reg == 0 ? 240 : inc_reg, 1}));
dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment