Commit 41c51d07 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Reorganize Thread Synchronization Steps to make sure global...

[Refactor] Reorganize Thread Synchronization Steps to make sure global synchronization can be correctly lowered (#521)

* [Refactor] Reorganize Thread Synchronization Steps in OptimizeForTarget Function

* Removed redundant thread synchronization steps for "global" and "shared" memory, streamlining the optimization process.
* Reintroduced necessary synchronization for "shared" and "shared.dyn" after the injection of PTX async copy, ensuring correct memory access patterns.
* Enhanced overall clarity and maintainability of the OptimizeForTarget function by restructuring the order of operations.

* [Refactor] Reorder Thread Synchronization and PTX Async Copy in OptimizeForTarget Function

* Removed redundant global thread synchronization step and adjusted the order of operations for shared memory synchronization.
* Ensured that the PTX async copy injection occurs after the global thread sync, improving memory access validity.
* Enhanced clarity and maintainability of the OptimizeForTarget function by restructuring synchronization steps.
parent 46798f25
......@@ -133,12 +133,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tilelang.transform.LowerHopperIntrin()(mod)
mod = tilelang.transform.ThreadSync("global")(mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
......@@ -150,6 +144,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
else:
mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
mod = tilelang.transform.ThreadSync("global")(mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)
# Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
mod = tilelang.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment