"docs/vscode:/vscode.git/clone" did not exist on "de60a3fb93957dce6b242299b5d163f02ef7f383"
Commit b060c9f7 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Align dynamic shared memory allocations in phase.py (#644)

- Added a comment to clarify the alignment of dynamic shared memory allocations in the `OptimizeForTarget` function.
- Refactored the handling of shared memory allocation merging and synchronization to streamline the process, ensuring consistent behavior regardless of the aggressive merge flag.
- Improved code clarity by removing redundant conditional checks related to synchronization and memory allocation.
parent 6c0a5841
...@@ -80,6 +80,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: ...@@ -80,6 +80,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LegalizeVectorizedLoop()(mod) mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
# Add safety checks for memory accesses # Add safety checks for memory accesses
mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod) mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod)
# Align dynamic shared memory allocations
# Simplify again to clean up any duplicated conditions # Simplify again to clean up any duplicated conditions
# that may have been introduced by safety checks # that may have been introduced by safety checks
# use an enhanced pass to simplify the dynamic symbolics # use an enhanced pass to simplify the dynamic symbolics
...@@ -167,19 +168,12 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -167,19 +168,12 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Hopper Swizzling requires dynamic shared memory address to be aligned to 1024 bytes # Hopper Swizzling requires dynamic shared memory address to be aligned to 1024 bytes
# For other devices, we align to 16 bytes # For other devices, we align to 16 bytes
smem_align_bytes = 1024 if have_tma(target) else 16 smem_align_bytes = 1024 if have_tma(target) else 16
if enable_aggressive_merge:
# Workaround, wait for a element wise synchronization pass # Workaround, wait for a element wise synchronization pass
mod = tilelang.transform.MergeSharedMemoryAllocations( mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=enable_aggressive_merge, align_bytes=smem_align_bytes)( enable_aggressive_merge=enable_aggressive_merge, align_bytes=smem_align_bytes)(
mod) mod)
mod = tilelang.transform.ThreadSync("shared")(mod) mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod) mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
else:
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=enable_aggressive_merge, align_bytes=smem_align_bytes)(
mod)
# Inject PTX async copy must behind the thread sync pass # Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load # as ptx async copy won't be recognized as a valid buffer load
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment