"docs/source/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "c6c361d80ada8117e926bd24f71f50bb5da9f0b3"
Commit 2abd6ab7 authored by Yu Cheng's avatar Yu Cheng Committed by LeiWang1999
Browse files

[Bugfix] Add TMA and Producer Buffer Analysis in Warp Specialized Rewriter (#269)

- Introduced TMAFinder and ProducerUsedBufferFinder classes to analyze TMA loads and identify buffers used in producer conditions.
- Enhanced WarpSpecializedRoleMarker to prepare and utilize the identified buffers during role marking.
- Updated VisitStmt methods to incorporate new analysis logic for IfThenElse and For nodes, improving the handling of TMA loads in the warp specialization process.
parent 47caf219
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* \brief Warp specialized Pipeline for cuda GPU (sm90+) * \brief Warp specialized Pipeline for cuda GPU (sm90+)
*/ */
#include "tir/analysis/var_use_def_analysis.h"
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
...@@ -37,11 +38,73 @@ using namespace tir; ...@@ -37,11 +38,73 @@ using namespace tir;
enum class Role { kConsumer, kProducer, kBoth }; enum class Role { kConsumer, kProducer, kBoth };
class TMAFinder : public StmtExprVisitor {
public:
void clear() { has_tma_load_ = false; }
void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
has_tma_load_ = true;
}
}
bool has_tma_load_ = false;
};
class ProducerUsedBufferFinder : public StmtExprVisitor {
public:
auto FindProducerusedBuffer(Stmt stmt) {
VisitStmt(stmt);
return used_in_producer_cond_;
}
void InsertBuffer(const PrimExpr &expr) {
// Find the buffer that is used in the condition
VarUseDefAnalyzer usage(Array<Var>{});
usage(expr);
for (const auto &buffer : usage.buffer_use_count_) {
used_in_producer_cond_.insert(buffer.first);
}
for (const auto &buffer : used_in_producer_cond_) {
}
}
void VisitStmt_(const IfThenElseNode *op) final {
TMAFinder tma_finder;
tma_finder(op->then_case);
if (op->else_case.defined()) {
tma_finder(op->else_case.value());
}
if (tma_finder.has_tma_load_) {
InsertBuffer(op->condition);
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const ForNode *op) final {
TMAFinder tma_finder;
tma_finder(op->body);
if (tma_finder.has_tma_load_) {
InsertBuffer(op->min);
InsertBuffer(op->extent);
}
StmtExprVisitor::VisitStmt_(op);
}
private:
std::unordered_set<const BufferNode *> used_in_producer_cond_;
};
class WarpSpecializedRoleMarker : public StmtVisitor { class WarpSpecializedRoleMarker : public StmtVisitor {
public: public:
WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer) WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {} : buffer_data_to_buffer_(buffer_data_to_buffer) {}
void Prepare(const Stmt &stmt) {
ProducerUsedBufferFinder finder;
used_in_producer_cond_ = finder.FindProducerusedBuffer(stmt);
}
Role GetRole(const StmtNode *stmt) const { Role GetRole(const StmtNode *stmt) const {
auto it = map_.find(stmt); auto it = map_.find(stmt);
ICHECK(it != map_.end()); ICHECK(it != map_.end());
...@@ -65,6 +128,10 @@ public: ...@@ -65,6 +128,10 @@ public:
void VisitStmt_(const BufferStoreNode *op) final { void VisitStmt_(const BufferStoreNode *op) final {
bool is_shared_store = bool is_shared_store =
op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
if (used_in_producer_cond_.count(op->buffer.get())) {
SetRole(op, Role::kBoth);
return;
}
if (!is_shared_store) { if (!is_shared_store) {
SetRole(op, Role::kConsumer); SetRole(op, Role::kConsumer);
return; return;
...@@ -136,6 +203,7 @@ private: ...@@ -136,6 +203,7 @@ private:
std::unordered_map<const StmtNode *, Role> map_; std::unordered_map<const StmtNode *, Role> map_;
bool has_simt_copy_ = false; bool has_simt_copy_ = false;
bool has_bulk_copy_ = false; bool has_bulk_copy_ = false;
std::unordered_set<const BufferNode *> used_in_producer_cond_;
}; };
static PrimExpr makeGetBarrier(PrimExpr barrier_id) { static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
...@@ -1073,6 +1141,7 @@ private: ...@@ -1073,6 +1141,7 @@ private:
Block block = block_realize->block; Block block = block_realize->block;
WarpSpecializedRoleMarker marker(buffer_data_to_buffer_); WarpSpecializedRoleMarker marker(buffer_data_to_buffer_);
marker.Prepare(block);
marker(block); marker(block);
if (!marker.HasProducer()) { if (!marker.HasProducer()) {
// Cannot detect any producer here, directly return. // Cannot detect any producer here, directly return.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment