Commit 0fdefe2b authored by Lei Wang, committed by LeiWang1999

[Refactor] Enhance MergeSharedMemoryAllocations Pass for Improved Liveness Analysis and Scope Management (#508)

* Introduced a new StmtAttr structure to track the scope level of each statement, enhancing the liveness analysis process (a toy sketch of this bookkeeping follows the commit metadata below).
* Updated the UpdateStmtAttr function to manage statement attributes effectively during memory allocation visits.
* Modified the VisitStmt_ methods to utilize the new scope level tracking, ensuring accurate memory access patterns.
* Refactored the LivenessAnalysis and PlanMemory functions to incorporate statement attributes, improving the handling of gen and kill points in memory management.
* Added a new helper function allow_warp_specialized in phase.py to conditionally enable warp specialization based on pass context and target, addressing potential bugs in the MergeSharedMemoryAllocations pass.
* Enhanced the OptimizeForTarget function to conditionally apply the MergeSharedMemoryAllocations pass based on warp specialization settings, improving robustness in memory allocation strategies.
parent f23c4d30
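To make the new bookkeeping concrete, here is a minimal Python sketch of the rule the pass implements: a For node whose body is a SeqStmt pushes one scope level, and each recorded statement is tagged with the level at which it was visited. The class and statement names are invented for illustration; only the level-tracking rule comes from the diff below.

    # Toy model of the scope_level_ / stmt_attrs_ tracking (names hypothetical).
    class ScopeTracker:
        def __init__(self):
            self.level = 0        # plays the role of scope_level_
            self.stmt_attrs = {}  # plays the role of stmt_attrs_: stmt -> level

        def visit_for(self, stmt_id, body_is_seq, visit_body):
            # Mirrors VisitStmt_(const ForNode*): only loops whose body is a
            # SeqStmt open a new level.
            if body_is_seq:
                self.level += 1
                visit_body()
                self.level -= 1
            else:
                visit_body()
            # Mirrors UpdateStmtAttr(op, scope_level_) after the scope closes.
            self.stmt_attrs[stmt_id] = self.level

        def visit_leaf(self, stmt_id):
            self.stmt_attrs[stmt_id] = self.level

    tracker = ScopeTracker()
    tracker.visit_for("pipelined_for", True,
                      lambda: tracker.visit_leaf("shared_copy"))
    assert tracker.stmt_attrs == {"shared_copy": 1, "pipelined_for": 0}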
@@ -119,6 +119,19 @@ public:
     const AllocateNode *alloc{nullptr};
   };

+  struct StmtAttr {
+    // the level in the scope stack
+    size_t level{0};
+  };
+
+  void UpdateStmtAttr(const Object *stmt, size_t level) {
+    if (stmt_attrs_.find(stmt) == stmt_attrs_.end()) {
+      stmt_attrs_[stmt] = StmtAttr{level};
+    } else {
+      stmt_attrs_[stmt].level = level;
+    }
+  }
+
   void VisitStmt_(const AllocateNode *op) final {
     size_t level = scope_.size();
     const VarNode *buf = op->buffer_var.get();
@@ -137,13 +150,14 @@ public:
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
       if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
-        scope_[it->second.level].touched.push_back(buf);
+        scope_[scope_.size() - 1].touched.push_back(buf);
       }
     }
     StmtEntry e = scope_.back();
     scope_.pop_back();
     if (e.touched.size() != 0) {
       e.stmt = op;
+      UpdateStmtAttr(op, scope_level_);
       linear_seq_.push_back(e);
     }
   }
@@ -156,6 +170,7 @@ public:
     scope_.pop_back();
     if (e.touched.size() != 0) {
       e.stmt = op;
+      UpdateStmtAttr(op, scope_level_);
       linear_seq_.push_back(e);
     }
   }
@@ -169,7 +184,7 @@ public:
       ICHECK_LT(it->second.level, scope_.size())
           << "Load memory in places other than store.";
       if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
-        scope_[it->second.level].touched.push_back(buf);
+        scope_[scope_.size() - 1].touched.push_back(buf);
       }
     }
   }
@@ -180,7 +195,7 @@ public:
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
       if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
-        scope_[it->second.level].touched.push_back(buf);
+        scope_[scope_.size() - 1].touched.push_back(buf);
       }
     }
   }
@@ -189,6 +204,7 @@ public:
     scope_.push_back(StmtEntry());
     StmtEntry e;
     e.stmt = op;
+    UpdateStmtAttr(op, scope_level_);
     int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
     // before scope.
     linear_seq_.push_back(e);
@@ -226,7 +242,15 @@ public:
   void VisitStmt_(const IfThenElseNode *op) final { VisitNewScope(op); }

-  void VisitStmt_(const ForNode *op) final { VisitNewScope(op); }
+  void VisitStmt_(const ForNode *op) final {
+    if (op->body->IsInstance<SeqStmtNode>()) {
+      scope_level_++;
+      VisitNewScope(op);
+      scope_level_--;
+    } else {
+      VisitNewScope(op);
+    }
+  }

   void VisitStmt_(const WhileNode *op) final { VisitNewScope(op); }
@@ -236,6 +260,8 @@ public:
   std::vector<StmtEntry> linear_seq_;
   // The storage scope of each buffer
   std::unordered_map<const VarNode *, AllocEntry> alloc_info_;
+  // The attribute of each statement
+  std::unordered_map<const Object *, StmtAttr> stmt_attrs_;

 private:
   // Wrapper function to determine if the shared memory allocation for a
@@ -251,6 +277,8 @@ private:
   bool in_thread_env_{false};
   // The scope stack.
   std::vector<StmtEntry> scope_;
+  // The current scope nesting level.
+  size_t scope_level_{0};
 };

 /*!
@@ -279,8 +307,8 @@ public:
                  bool verbose = false) {
     SharedMemLinearAccessPatternFinder finder(is_dynamic, verbose);
     finder(stmt);
-    this->LivenessAnalysis(finder.linear_seq_);
-    this->PlanMemory(finder.linear_seq_);
+    this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
+    this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
   }
 private:
@@ -491,6 +519,7 @@ private:
   }

   using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
+  using StmtAttr = SharedMemLinearAccessPatternFinder::StmtAttr;

   struct StorageEntry {
     // The constant size of the buffer in bits, only used if it is constant
     uint64_t const_nbits{0};
@@ -515,7 +544,9 @@ private:
    * \brief Liveness analysis to find gen and kill point of each variable.
    * \param seq the linear pattern of storage access
    */
-  void LivenessAnalysis(const std::vector<StmtEntry> &seq) {
+  void LivenessAnalysis(
+      const std::vector<StmtEntry> &seq,
+      const std::unordered_map<const Object *, StmtAttr> &stmt_attrs) {
     // find kill point, do a reverse linear scan.
     std::unordered_set<const VarNode *> touched;
     for (size_t i = seq.size(); i != 0; --i) {
@@ -543,17 +574,174 @@ private:
     }

     if (verbose_) {
-      LOG(DEBUG) << "Liveness Analysis Results for "
-                 << (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:";
-      for (const auto &pair : event_map_) {
-        const Object *stmt_obj = pair.first;
-        const EventEntry &entry = pair.second;
-        if (entry.gen.empty() && entry.kill.empty()) {
-          continue; // Skip statements with no gen/kill events for brevity
-        }
-        LOG(DEBUG) << "  Statement: " << stmt_obj->GetTypeKey();
+      std::vector<const Object *> stmt_keys;
+      for (const auto &stmt_entry : seq) {
+        auto stmt = stmt_entry.stmt;
+        if (std::find(stmt_keys.begin(), stmt_keys.end(), stmt) ==
+            stmt_keys.end()) {
+          stmt_keys.push_back(stmt);
+        }
+      }
+      LOG(DEBUG) << "Before reorder kill points, Liveness Analysis Results for "
+                 << (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:";
+      for (const auto &stmt_key : stmt_keys) {
+        auto it = event_map_.find(stmt_key);
+        if (it == event_map_.end())
+          continue;
+        const EventEntry &entry = it->second;
+        if (entry.gen.empty() && entry.kill.empty())
+          continue;
+        ICHECK(stmt_attrs.count(stmt_key))
+            << "stmt_key = " << stmt_key->GetTypeKey();
+        auto level = stmt_attrs.at(stmt_key).level;
+        LOG(DEBUG) << "  Statement: " << stmt_key->GetTypeKey()
+                   << " (scope_level: " << level << ")";
+        std::stringstream gen_vars_ss;
+        bool x_generated = false;
+        for (const VarNode *var : entry.gen) {
+          gen_vars_ss << var->name_hint << " ";
+          if (var->name_hint == "x") {
+            x_generated = true;
+          }
+        }
+        if (!entry.gen.empty()) {
+          std::string gen_log_msg = "    GEN: " + gen_vars_ss.str();
+          if (x_generated) {
+            gen_log_msg += "  <-- Buffer 'x' generated";
+          }
+          LOG(DEBUG) << gen_log_msg;
+        }
+        std::stringstream kill_vars_ss;
+        bool x_killed = false;
+        for (const VarNode *var : entry.kill) {
+          kill_vars_ss << var->name_hint << " ";
+          if (var->name_hint == "x") {
+            x_killed = true;
+          }
+        }
+        if (!entry.kill.empty()) {
+          std::string kill_log_msg = "    KILL: " + kill_vars_ss.str();
+          if (x_killed) {
+            kill_log_msg += "  <-- Buffer 'x' killed";
+          }
+          LOG(DEBUG) << kill_log_msg;
+        }
+      }
+      LOG(DEBUG) << "End of Liveness Analysis Results.";
+    }
+
+    // Reorder kill points:
+    // For each buffer, if its kill statement is at a deeper scope level than
+    // its gen statement, we need to move the kill point to the end of the gen
+    // statement's scope level. This ensures proper memory deallocation at the
+    // right scope boundary.
+    std::vector<StmtEntry> gen_kill_seq;
+    for (const auto &stmt_entry : seq) {
+      // if has gen and kill, add to gen_kill_seq
+      if (event_map_[stmt_entry.stmt].gen.size() > 0 ||
+          event_map_[stmt_entry.stmt].kill.size() > 0) {
+        gen_kill_seq.push_back(stmt_entry);
+      }
+    }
+
+    for (auto &event_pair : event_map_) {
+      const Object *stmt = event_pair.first;
+      EventEntry &event = event_pair.second;
+      // Skip if no kill points to process
+      if (event.kill.empty())
+        continue;
+      // Get scope level of current statement
+      ICHECK(stmt_attrs.count(stmt));
+      int kill_level = stmt_attrs.at(stmt).level;
+      std::unordered_set<const VarNode *> visited_buffers;
+      // For each killed buffer, find its gen statement and check scope levels
+      for (auto it = event.kill.begin(); it != event.kill.end();) {
+        const VarNode *buffer = *it;
+        bool found_gen = false;
+        int gen_level = 0;
+        // Find the gen statement for this buffer
+        for (const auto &gen_pair : event_map_) {
+          const auto &gen_event = gen_pair.second;
+          if (std::find(gen_event.gen.begin(), gen_event.gen.end(), buffer) !=
+              gen_event.gen.end()) {
+            found_gen = true;
+            gen_level = stmt_attrs.at(gen_pair.first).level;
+            break;
+          }
+        }
+        if (found_gen && kill_level > gen_level) {
+          if (visited_buffers.count(buffer)) {
+            ++it;
+            continue;
+          }
+          // Need to move kill point - remove from current event
+          it = event.kill.erase(it);
+          // Find the last statement at gen_level in the sequence, starting
+          // from the current statement
+          const Object *last_stmt_at_level = nullptr;
+          auto stmt_it = gen_kill_seq.begin();
+          for (; stmt_it != gen_kill_seq.end(); ++stmt_it) {
+            if (stmt_it->stmt == stmt) {
+              break;
+            }
+          }
+          for (; stmt_it != gen_kill_seq.end(); ++stmt_it) {
+            // Check if the next statement returns to the gen level
+            auto next_it = stmt_it + 1;
+            if (next_it == gen_kill_seq.end() ||
+                stmt_attrs.at(next_it->stmt).level == gen_level) {
+              last_stmt_at_level = stmt_it->stmt;
+              break;
+            }
+          }
+          if (last_stmt_at_level) {
+            event_map_[last_stmt_at_level].kill.push_back(buffer);
+            visited_buffers.insert(buffer);
+          }
+        } else {
+          ++it;
+        }
+      }
+    }
+
+    std::vector<const Object *> stmt_keys;
+    for (const auto &stmt_entry : seq) {
+      auto stmt = stmt_entry.stmt;
+      if (std::find(stmt_keys.begin(), stmt_keys.end(), stmt) ==
+          stmt_keys.end()) {
+        stmt_keys.push_back(stmt);
+      }
+    }
+    if (verbose_) {
+      LOG(DEBUG) << "Liveness Analysis Results for "
+                 << (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:";
+      for (const auto &stmt_key : stmt_keys) {
+        auto it = event_map_.find(stmt_key);
+        if (it == event_map_.end())
+          continue;
+        const EventEntry &entry = it->second;
+        if (entry.gen.empty() && entry.kill.empty())
+          continue;
+        ICHECK(stmt_attrs.count(stmt_key))
+            << "stmt_key = " << stmt_key->GetTypeKey();
+        auto level = stmt_attrs.at(stmt_key).level;
+        LOG(DEBUG) << "  Statement: " << stmt_key->GetTypeKey()
+                   << " (scope_level: " << level << ")";
         std::stringstream gen_vars_ss;
         bool x_generated = false;
@@ -596,7 +784,9 @@ private:
    * \param seq the linear pattern of storage access
    * \param alloc_info
    */
-  void PlanMemory(const std::vector<StmtEntry> &seq) {
+  void
+  PlanMemory(const std::vector<StmtEntry> &seq,
+             const std::unordered_map<const Object *, StmtAttr> &stmt_attrs) {
     std::unordered_set<const VarNode *> inplace_flag;
     for (size_t i = 0; i < seq.size(); ++i) {
...
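Below is a minimal, runnable sketch of the kill-point reordering rule introduced above. It is simplified in that it scans the full statement sequence rather than only statements with gen/kill events, and the record layout is an invented stand-in for StmtEntry plus its scope level; only the rule itself comes from the pass.

    # Sketch of the reordering rule: a kill that sits at a deeper scope level
    # than its gen is moved to the last statement before the level returns to
    # the gen level, so the buffer stays live across the whole inner scope.
    def reorder_kills(seq):
        """seq: list of {"level": int, "gen": set, "kill": set} records."""
        gen_level = {}
        for stmt in seq:
            for buf in stmt["gen"]:
                gen_level[buf] = stmt["level"]
        for i, stmt in enumerate(seq):
            for buf in list(stmt["kill"]):
                if buf in gen_level and stmt["level"] > gen_level[buf]:
                    stmt["kill"].remove(buf)
                    for j in range(i, len(seq)):
                        # attach the kill to the statement just before the
                        # scope level drops back to the buffer's gen level
                        if j + 1 == len(seq) or seq[j + 1]["level"] == gen_level[buf]:
                            seq[j]["kill"].add(buf)
                            break
        return seq

    seq = [
        {"level": 0, "gen": {"x"}, "kill": set()},  # x allocated at level 0
        {"level": 1, "gen": set(), "kill": {"x"}},  # last use inside a loop
        {"level": 1, "gen": {"y"}, "kill": {"y"}},  # later event, same level
        {"level": 0, "gen": set(), "kill": set()},  # back at level 0
    ]
    reorder_kills(seq)
    assert seq[1]["kill"] == set() and "x" in seq[2]["kill"]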
@@ -6,6 +6,19 @@ from tilelang.contrib.nvcc import have_tma
 from typing import Optional


+def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,
+                           target: Optional[Target] = None) -> bool:
+    # avoid circular import
+    from tilelang.jit.adapter.utils import is_cuda_target
+    if pass_ctx is None:
+        pass_ctx = tilelang.transform.get_pass_context()
+    if not is_cuda_target(target):
+        return False
+    disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
+    return not disable_warp_specialized
+
+
 def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
                                    target: Optional[Target] = None) -> bool:
     # avoid circular import
@@ -16,9 +29,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
     if not is_cuda_target(target) or not have_tma(target):
         return False
     disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
-    disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
-    disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
-    return not (disable_tma_lower and disable_warp_specialized)
+    return not disable_tma_lower and allow_warp_specialized(pass_ctx=pass_ctx, target=target)


 def allow_fence_proxy(target: Optional[Target] = None) -> bool:
@@ -128,7 +139,15 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.AnnotateDeviceRegions()(mod)
     mod = tir.transform.SplitHostDevice()(mod)
-    mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
+    if allow_warp_specialized(pass_ctx=pass_ctx, target=target):
+        # Workaround for a bug in the MergeSharedMemoryAllocations pass when
+        # warp specialization is enabled: different warps may access different
+        # buffers, and the liveness analysis is hard because pipelining has to
+        # be taken into account.
+        mod = tir.transform.MergeSharedMemoryAllocations()(mod)
+    else:
+        mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
     mod = tilelang.transform.MakePackedAPI()(mod)
     mod = tir.transform.LowerDeviceKernelLaunch()(mod)
...
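A hedged usage note: assuming the "tl.disable_warp_specialized" key is registered as a TVM pass-config option (the helper above reads it from the pass context), warp specialization, and with it the fallback to the upstream TIR pass, can be switched off like this:

    # Hypothetical usage: force the tilelang variant of
    # MergeSharedMemoryAllocations by disabling warp specialization through
    # the pass-context config (TVM PassContext convention).
    import tvm

    with tvm.transform.PassContext(config={"tl.disable_warp_specialized": True}):
        # Inside this context allow_warp_specialized(...) returns False, so
        # OptimizeForTarget applies tilelang.transform.MergeSharedMemoryAllocations
        # instead of tir.transform.MergeSharedMemoryAllocations.
        ...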