Commit 3ca5a4ba authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Introduce PassConfig `TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE`...

[Enhancement] Introduce PassConfig `TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE` to enable aggressive shared memory reuse (#602)

* [Enhancement] Add aggressive shared memory merge option in memory allocation

- Introduced a new configuration option `tl.enable_aggressive_shared_memory_merge` to enable aggressive merging of shared memory allocations.
- Updated the `SharedMemLinearAccessPatternFinder` class to support an aggressive merge strategy, allowing for improved memory reuse.
- Modified the `MergeSharedMemoryAllocations` function to incorporate the new merging strategy based on the configuration.
- Enhanced the `PassConfigKey` enumeration to include the new aggressive merge option, ensuring it can be configured appropriately.

* lint fix

* [Enhancement] Add aggressive shared memory merge configuration option

- Introduced a new configuration option `kEnableAggressiveSharedMemoryMerge` to enable aggressive merging of shared memory allocations, enhancing memory management capabilities.

* [Enhancement] Update MergeSharedMemoryAllocations to support aggressive merge option

- Modified the `MergeSharedMemoryAllocations` function to accept an `enable_aggressive_merge` parameter, allowing for more flexible memory management.
- Introduced a new helper function `should_enable_aggressive_merge` to determine the aggressive merge configuration based on the pass context and target.
- Updated the relevant calls in the `phase.py` and `__init__.py` files to utilize the new aggressive merge functionality, enhancing the overall memory allocation strategy.
parent a664c998
......@@ -23,6 +23,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op &OpName() { \
......
......@@ -25,6 +25,8 @@ static constexpr const char *kDisableSafeMemoryLegalize =
static constexpr const char *kDisableWarpSpecialized =
"tl.disable_warp_specialized";
static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
static constexpr const char *kEnableAggressiveSharedMemoryMerge =
"tl.enable_aggressive_shared_memory_merge";
/*!
* \brief Whether to disable dynamic tail split
......
......@@ -95,9 +95,11 @@ public:
//
class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
public:
/*!
 * \brief Construct the shared-memory linear access pattern finder.
 * \param is_dynamic Whether the analyzed allocations are dynamic shared memory.
 * \param enable_aggressive_merge Whether to use the aggressive merge strategy
 *        (touched buffers are recorded in the current innermost scope,
 *        `scope_[scope_.size() - 1]`, to increase memory reuse).
 * \param verbose Whether to emit verbose logging.
 */
explicit SharedMemLinearAccessPatternFinder(
    bool is_dynamic = true, bool enable_aggressive_merge = false,
    bool verbose = false)
    : is_dynamic_(is_dynamic),
      enable_aggressive_merge_(enable_aggressive_merge), verbose_(verbose) {}
/*! \brief record the touch list of statement. */
struct StmtEntry {
// The statement
......@@ -151,9 +153,15 @@ public:
ICHECK_LT(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
// set into scope_.size() - 1 for aggressive memory reuse
scope_[it->second.level].touched.push_back(buf);
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
scope_[it->second.level].touched.push_back(buf);
}
}
}
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.touched.size() != 0) {
......@@ -185,7 +193,12 @@ public:
ICHECK_LT(it->second.level, scope_.size())
<< "Load memory in places other than store.";
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[it->second.level].touched.push_back(buf);
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
scope_[it->second.level].touched.push_back(buf);
}
}
}
}
......@@ -196,7 +209,12 @@ public:
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[it->second.level].touched.push_back(buf);
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
scope_[it->second.level].touched.push_back(buf);
}
}
}
}
......@@ -284,6 +302,8 @@ private:
}
// Whether to do dynamic analysis.
bool is_dynamic_{true};
// Whether to do aggressive merge.
bool enable_aggressive_merge_{false};
// Whether do verbose logging.
bool verbose_{false};
// Whether already in thread env.
......@@ -317,8 +337,9 @@ public:
* \param stmt the statement
*/
void PlanReuse(const Stmt &stmt, bool is_dynamic = true,
bool verbose = false) {
SharedMemLinearAccessPatternFinder finder(is_dynamic, verbose);
bool enable_aggressive_merge = false, bool verbose = false) {
SharedMemLinearAccessPatternFinder finder(is_dynamic,
enable_aggressive_merge, verbose);
finder(stmt);
this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
......@@ -956,6 +977,7 @@ private:
}
// Whether to enable dynamic analysis.
bool is_dynamic_{true};
// Whether enable verbose logging.
bool verbose_{false};
// The var for the merged buffer
......@@ -985,18 +1007,19 @@ private:
};
Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem,
bool enable_aggressive_merge,
bool verbose = false) {
AllocateCollector collector;
collector(stmt);
if (collector.dyn_shmem_allocs_.size() > 1) {
SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_, true, verbose);
rewriter.PlanReuse(stmt);
rewriter.PlanReuse(stmt, true, enable_aggressive_merge);
stmt = rewriter(std::move(stmt));
}
if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false,
verbose);
rewriter.PlanReuse(stmt, false);
rewriter.PlanReuse(stmt, false, enable_aggressive_merge);
stmt = rewriter(std::move(stmt));
}
return stmt;
......@@ -1006,17 +1029,18 @@ using namespace tir::transform;
namespace transform {
Pass MergeSharedMemoryAllocations() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false) {
auto pass_func = [enable_aggressive_merge](PrimFunc f, IRModule m,
PassContext ctx) {
bool merge_static_smem =
ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
bool debug_merge_shared_memory_allocations =
ctx->GetConfig<Bool>(kDebugMergeSharedMemoryAllocations, Bool(false))
.value();
auto *n = f.CopyOnWrite();
n->body =
tl::MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem,
debug_merge_shared_memory_allocations);
n->body = tl::MergeSharedMemoryAllocations(
std::move(n->body), merge_static_smem, enable_aggressive_merge,
debug_merge_shared_memory_allocations);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.MergeSharedMemoryAllocations",
......
......@@ -48,6 +48,20 @@ def allow_global_thread_synchronization(pass_ctx: Optional[PassContext] = None)
return enable_global_thread_sync
def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
                                   target: Optional[Target] = None) -> bool:
    """Decide whether aggressive shared-memory merging should be enabled.

    Reads ``TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE`` from the pass config,
    but forces the result to ``False`` whenever warp specialization is
    allowed for this context/target.

    Parameters
    ----------
    pass_ctx : Optional[PassContext]
        Pass context to inspect; the current context is used when ``None``.
    target : Optional[Target]
        Compilation target, forwarded to the warp-specialization check.

    Returns
    -------
    bool
        ``True`` when the aggressive merge strategy should be used.
    """
    ctx = pass_ctx if pass_ctx is not None else tilelang.transform.get_pass_context()
    aggressive = bool(
        ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False))
    # Workaround for a bug in the MergeSharedMemoryAllocations pass: with warp
    # specialization enabled, different warp threads may access different
    # buffers, and the pipelined liveness analysis cannot model that safely,
    # so aggressive merging is disabled in that case.
    if allow_warp_specialized(pass_ctx=ctx, target=target):
        return False
    return aggressive
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# Bind the target device information to the module
mod = tir.transform.BindTarget(target)(mod)
......@@ -149,7 +163,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target))(
mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
......
......@@ -333,7 +333,7 @@ def EliminateStorageSyncForMBarrier():
return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore
def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False):
    """MergeSharedMemoryAllocations

    Parameters
    ----------
    enable_aggressive_merge : bool
        Whether to merge shared memory allocations aggressively.
        Default: False.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge)  # type: ignore
def LowerL2Persistent():
......
......@@ -30,6 +30,9 @@ class PassConfigKey(str, Enum):
TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations"
"""Enable debug information for merge shared memory allocations. Default: False"""
TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE = "tl.enable_aggressive_shared_memory_merge"
"""Enable aggressive merge of shared memory allocations. Default: False"""
# TIR related configs
TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
"""Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment