Commit 3ca5a4ba authored by Lei Wang, committed by LeiWang1999

[Enhancement] Introduce PassConfig `TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE` to enable aggressive shared memory reuse (#602)

* [Enhancement] Add aggressive shared memory merge option in memory allocation

- Introduced a new configuration option `tl.enable_aggressive_shared_memory_merge` to enable aggressive merging of shared memory allocations.
- Updated the `SharedMemLinearAccessPatternFinder` class to support an aggressive merge strategy that records buffer touches at the innermost open scope, shortening live ranges and improving shared memory reuse (see the illustrative model below).
- Modified the `MergeSharedMemoryAllocations` function to apply the new merge strategy based on the configuration.
- Extended the `PassConfigKey` enumeration with the new option so it can be set through the pass context.
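To make the strategy difference concrete, here is an illustrative Python model of the touch-recording logic in `SharedMemLinearAccessPatternFinder` (simplified names and a hypothetical helper; not the actual implementation):

from dataclasses import dataclass, field

@dataclass
class Scope:
    """One entry of the finder's scope stack; `touched` feeds liveness analysis."""
    touched: list = field(default_factory=list)

def record_touch(scope_stack: list, alloc_level: int, buf: str,
                 aggressive: bool) -> None:
    # Illustrative model of the C++ branches in the diff below, not the real code.
    if aggressive:
        # Attribute the access to the innermost open scope
        # (scope_[scope_.size() - 1] in the C++ code), so the buffer's
        # live range ends with the statement that actually touches it
        # and its bytes can be reused sooner.
        scope_stack[-1].touched.append(buf)
    else:
        # Conservative default: attribute the access to the scope where the
        # buffer was allocated, keeping it live across that entire scope.
        scope_stack[alloc_level].touched.append(buf)

Shorter live ranges let the planner overlap more allocations in the merged buffer; this is also why the option is forced off under warp specialization (see `should_enable_aggressive_merge` below).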

* lint fix

* [Enhancement] Add aggressive shared memory merge configuration option

- Added the `kEnableAggressiveSharedMemoryMerge` constant and registered it as a pass config option so the setting can be read from the C++ passes.

* [Enhancement] Update MergeSharedMemoryAllocations to support aggressive merge option

- Modified the `MergeSharedMemoryAllocations` pass to accept an `enable_aggressive_merge` parameter.
- Introduced a helper function `should_enable_aggressive_merge` that derives the setting from the pass context and target, forcing it off when warp specialization is enabled: different warp threads may access different buffers, which the linear liveness analysis cannot model safely.
- Updated the calls in `phase.py` and `__init__.py` to pass the new flag through to the C++ pass (see the usage sketch below).
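For illustration, a minimal sketch of how the option might be enabled from user code. The config key comes from this commit; that importing tilelang registers the `tl.*` options with TVM's PassContext is an assumption:

import tvm
import tilelang  # assumed to register the tl.* pass config options

with tvm.transform.PassContext(
        config={"tl.enable_aggressive_shared_memory_merge": True}):
    # Any lowering that runs tl.MergeSharedMemoryAllocations inside this
    # context now plans shared memory reuse aggressively, unless
    # should_enable_aggressive_merge() forces it off because warp
    # specialization is enabled for the target.
    pass  # compile / lower the TileLang kernel here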
parent a664c998
@@ -23,6 +23,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
+TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
 
 #define TIR_DEFINE_TL_BUILTIN(OpName) \
   const Op &OpName() { \
...
@@ -25,6 +25,8 @@ static constexpr const char *kDisableSafeMemoryLegalize =
 static constexpr const char *kDisableWarpSpecialized =
     "tl.disable_warp_specialized";
 static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
+static constexpr const char *kEnableAggressiveSharedMemoryMerge =
+    "tl.enable_aggressive_shared_memory_merge";
 
 /*!
  * \brief Whether to disable dynamic tail split
...
@@ -95,9 +95,11 @@ public:
 //
 class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
 public:
-  explicit SharedMemLinearAccessPatternFinder(bool is_dynamic = true,
-                                              bool verbose = false)
-      : is_dynamic_(is_dynamic), verbose_(verbose) {}
+  explicit SharedMemLinearAccessPatternFinder(
+      bool is_dynamic = true, bool enable_aggressive_merge = false,
+      bool verbose = false)
+      : is_dynamic_(is_dynamic),
+        enable_aggressive_merge_(enable_aggressive_merge), verbose_(verbose) {}
   /*! \brief record the touch list of statement. */
   struct StmtEntry {
     // The statement
@@ -151,9 +153,15 @@ public:
       ICHECK_LT(it->second.level, scope_.size());
       if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
-        scope_[it->second.level].touched.push_back(buf);
+        // set into scope_.size() - 1 for aggressive memory reuse
+        auto enable_aggressive_merge = enable_aggressive_merge_;
+        if (enable_aggressive_merge) {
+          scope_[scope_.size() - 1].touched.push_back(buf);
+        } else {
+          scope_[it->second.level].touched.push_back(buf);
+        }
       }
     }
 
     StmtEntry e = scope_.back();
     scope_.pop_back();
     if (e.touched.size() != 0) {
@@ -185,7 +193,12 @@ public:
       ICHECK_LT(it->second.level, scope_.size())
           << "Load memory in places other than store.";
       if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
-        scope_[it->second.level].touched.push_back(buf);
+        auto enable_aggressive_merge = enable_aggressive_merge_;
+        if (enable_aggressive_merge) {
+          scope_[scope_.size() - 1].touched.push_back(buf);
+        } else {
+          scope_[it->second.level].touched.push_back(buf);
+        }
       }
     }
   }
@@ -196,7 +209,12 @@ public:
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
       if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
-        scope_[it->second.level].touched.push_back(buf);
+        auto enable_aggressive_merge = enable_aggressive_merge_;
+        if (enable_aggressive_merge) {
+          scope_[scope_.size() - 1].touched.push_back(buf);
+        } else {
+          scope_[it->second.level].touched.push_back(buf);
+        }
       }
     }
   }
@@ -284,6 +302,8 @@ private:
   }
   // Whether do dyanmic analysis.
   bool is_dynamic_{true};
+  // Whether do aggressive merge.
+  bool enable_aggressive_merge_{false};
   // Whether do verbose logging.
   bool verbose_{false};
   // Whether already in thread env.
@@ -317,8 +337,9 @@ public:
    * \param stmt the statement
    */
   void PlanReuse(const Stmt &stmt, bool is_dynamic = true,
-                 bool verbose = false) {
-    SharedMemLinearAccessPatternFinder finder(is_dynamic, verbose);
+                 bool enable_aggressive_merge = false, bool verbose = false) {
+    SharedMemLinearAccessPatternFinder finder(is_dynamic,
+                                              enable_aggressive_merge, verbose);
     finder(stmt);
     this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
     this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
@@ -956,6 +977,7 @@ private:
   }
   // Wheather enable dyanmic analysis.
   bool is_dynamic_{true};
   // Whether enable verbose logging.
   bool verbose_{false};
   // The var for the merged buffer
@@ -985,18 +1007,19 @@ private:
 };
 
 Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem,
+                                  bool enable_aggressive_merge,
                                   bool verbose = false) {
   AllocateCollector collector;
   collector(stmt);
   if (collector.dyn_shmem_allocs_.size() > 1) {
     SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_, true, verbose);
-    rewriter.PlanReuse(stmt);
+    rewriter.PlanReuse(stmt, true, enable_aggressive_merge);
     stmt = rewriter(std::move(stmt));
   }
   if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
     SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false,
                                   verbose);
-    rewriter.PlanReuse(stmt, false);
+    rewriter.PlanReuse(stmt, false, enable_aggressive_merge);
     stmt = rewriter(std::move(stmt));
   }
   return stmt;
@@ -1006,17 +1029,18 @@ using namespace tir::transform;
 
 namespace transform {
 
-Pass MergeSharedMemoryAllocations() {
-  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false) {
+  auto pass_func = [enable_aggressive_merge](PrimFunc f, IRModule m,
+                                             PassContext ctx) {
     bool merge_static_smem =
         ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
     bool debug_merge_shared_memory_allocations =
        ctx->GetConfig<Bool>(kDebugMergeSharedMemoryAllocations, Bool(false))
            .value();
     auto *n = f.CopyOnWrite();
-    n->body =
-        tl::MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem,
-                                         debug_merge_shared_memory_allocations);
+    n->body = tl::MergeSharedMemoryAllocations(
+        std::move(n->body), merge_static_smem, enable_aggressive_merge,
+        debug_merge_shared_memory_allocations);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.MergeSharedMemoryAllocations",
...
@@ -48,6 +48,20 @@ def allow_global_thread_synchronization(pass_ctx: Optional[PassContext] = None)
     return enable_global_thread_sync
 
 
+def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
+                                   target: Optional[Target] = None) -> bool:
+    if pass_ctx is None:
+        pass_ctx = tilelang.transform.get_pass_context()
+    enable_aggressive_merge = bool(
+        pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False))
+    if allow_warp_specialized(pass_ctx=pass_ctx, target=target):
+        # This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass
+        # when warp specialization is enabled, as different warp threads may access different
+        # buffers, but the liveness analysis is hard because we need to do pipeline.
+        enable_aggressive_merge = False
+    return enable_aggressive_merge
+
+
 def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
     # Bind the target device information to the module
     mod = tir.transform.BindTarget(target)(mod)

@@ -149,7 +163,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.AnnotateDeviceRegions()(mod)
     mod = tir.transform.SplitHostDevice()(mod)
-    mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
+    mod = tilelang.transform.MergeSharedMemoryAllocations(
+        enable_aggressive_merge=should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target))(
+            mod)
     mod = tilelang.transform.ThreadSync("shared")(mod)
     mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
...
@@ -333,7 +333,7 @@ def EliminateStorageSyncForMBarrier():
     return _ffi_api.EliminateStorageSyncForMBarrier()  # type: ignore


-def MergeSharedMemoryAllocations():
+def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False):
     """MergeSharedMemoryAllocations

     Returns
@@ -341,7 +341,7 @@ def MergeSharedMemoryAllocations():
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.MergeSharedMemoryAllocations()  # type: ignore
+    return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge)  # type: ignore


def LowerL2Persistent():
...
@@ -30,6 +30,9 @@ class PassConfigKey(str, Enum):
     TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations"
     """Enable debug information for merge shared memory allocations. Default: False"""

+    TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE = "tl.enable_aggressive_shared_memory_merge"
+    """Enable aggressive merge of shared memory allocations. Default: False"""
+
     # TIR related configs
     TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
     """Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""
...