Commit e2d25ba8 authored by Lei Wang, committed by LeiWang1999

[Warp Specialize] Implicit Warp Specialize Programming Model (#605)

* [Enhancement] Improve memory access condition checks in GlobalMemChecker

- Updated the condition checks in the GlobalMemChecker to utilize symbolic bounds in the CanProve method, enhancing the accuracy of memory access validations.
- This change ensures that both upper and lower bound conditions are evaluated with improved proof strength, contributing to more robust memory access analysis.

* lintfix

* [Enhancement] Add legality checks for shared memory and global range in LowerBulkCopy

- Implemented checks to ensure that the shared memory range and global range are legal during the bulk copy operation.
- Added assertions to validate that the extents of global and shared ranges match, improving the robustness of memory access validation in the LowerBulkCopy function.

* [Refactor] Update barrier and clear operations in warp specialization examples

- Replaced `mbarrier_wait_parity` and `mbarrier_arrive` with `barrier_wait` and `barrier_arrive` for improved clarity and consistency in synchronization.
- Adjusted the order of `clear` operations for local fragments in `example_warp_specialize_gemm_copy_1_gemm_0` to enhance parallel execution efficiency.
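
As a rough illustration of the implicit warp-specialize programming model these examples move toward, a minimal producer/consumer sketch using the new barrier helpers could look like the following. The shapes, grid size, and kernel body are hypothetical; only `T.alloc_barrier`, `T.barrier_arrive`, and `T.barrier_wait` are taken from this PR.

```python
import tilelang.language as T


# Hedged sketch (not one of the PR's examples): stage a tile through shared
# memory and hand it off with the new barrier helpers. Shapes are illustrative.
@T.prim_func
def pipeline(A: T.Tensor((1024, 64), "float16"), B: T.Tensor((1024, 64), "float16")):
    with T.Kernel(8, threads=256) as bx:
        A_shared = T.alloc_shared((128, 64), "float16")
        bar = T.alloc_barrier(arrive_count=128)

        # Producer warps: copy a tile into shared memory, then arrive.
        T.copy(A[bx * 128, 0], A_shared)
        T.barrier_arrive(bar)

        # Consumer warps: wait on parity 0, then use the staged tile.
        T.barrier_wait(bar, 0)
        T.copy(A_shared, B[bx * 128, 0])
```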

* [Enhancement] Implement thread partial synchronization and improve shared memory allocation handling

- Added support for thread partial barrier synchronization in CUDA, allowing for more flexible thread management.
- Enhanced the `MergeSharedMemoryAllocations` function to accept alignment bytes, improving memory allocation efficiency based on target requirements.
- Updated the `Lower` methods in `Copy` and `Fill` classes to include conditional predicates for thread execution, ensuring better control over thread behavior.
- Refactored the `print` function to include warp group and warp IDs for more detailed debugging output.
- Improved the handling of dynamic shared memory allocations in the `LowerAndLegalize` function to align with target-specific requirements.

* [Enhancement] Add support for disabling TMA in Copy operations

- Introduced a new `disable_tma` parameter in the `Copy` class to control whether TMA-based bulk copies are used.
- Updated the `Lower` method to conditionally execute bulk copy operations based on the `disable_tma` flag.
- Enhanced the `copy` function to accept the `disable_tma` argument, allowing for more flexible memory copy operations.
- Improved handling of `coalesced_width` to ensure it defaults to -1 when not provided, enhancing robustness in memory operations.

* [Refactor] Clean up whitespace and formatting in multiple files

- Removed unnecessary blank lines and adjusted line breaks for improved code readability in `example_mla_decode.py`, `example_warp_specialize_gemm_copy_gemm_0_1.py`, `phase.py`, and `copy.py`.
- Ensured consistent formatting across functions to enhance maintainability and clarity of the codebase.

* [Enhancement] Refactor flash attention implementation for improved performance and configurability

- Split the shared memory allocations for query and key-value pairs to optimize memory usage.
- Introduced command-line arguments for batch size, number of heads, and dimensions, enhancing flexibility in running the example.
- Updated kernel execution parameters to improve thread management and synchronization.
- Enhanced the overall structure of the flash attention function for better readability and maintainability.

* fix

* Update layout inference in ParallelOp to account for thread bounds; remove debug print in OptimizeForTarget

* Refactor barrier handling and update example configurations

- Replaced commented-out barrier creation with new barrier allocation in GEMM example.
- Updated kernel configuration in warp specialization example to include async copy settings.
- Enhanced barrier management in the phase optimization process to improve synchronization handling.
- Introduced new barrier allocation function for better memory management in shared contexts.

* Refactor barrier handling in LowerAndLegalize and OptimizeForTarget

- Reintroduced barrier lowering in OptimizeForTarget to enhance synchronization.
- Removed commented-out barrier lowering in LowerAndLegalize for cleaner code.
- Added exit() call in OptimizeForTarget to halt execution after barrier lowering.

* Enhance CMake configuration and clean up example scripts

- Enabled compile command export in CMakeLists.txt for better build integration.
- Removed unnecessary print statement in the warp specialization example.
- Cleaned up commented-out code in GEMM example for improved readability.
- Updated barrier handling in shared memory allocation transformations for better synchronization.

* Refactor barrier handling in warp specialization examples

- Replaced commented-out mbarrier code with new barrier allocation using T.alloc_barrier for improved synchronization.
- Updated barrier wait and arrive calls to align with the new allocation method across multiple example scripts.
- Enhanced code readability by removing unnecessary comments and ensuring consistent barrier management.

* Update lower_shared_barrier.cc

* Update phase.py

* Update warp specialization example and Cython wrapper

- Removed commented-out pass configuration options in the warp specialization example for clarity.
- Added functionality to write the generated kernel source to a file named "kernel.cu".
- Enhanced Cython wrapper to support boolean type conversion for improved type handling.

* Add storage synchronization call in shared barrier transformation

- Introduced a new evaluation statement to call the TVM storage sync function with "shared" as an argument, enhancing synchronization in the shared barrier handling process.

* remove debug files

* Remove kernel source output to file in warp specialization example

* remove comments

* Refactor tensor handling and update test execution in TileLang

- Changed `Buffer` to `Tensor` in `customize.py` for better type consistency.
- Updated `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to use `tir.BufferLoad` instead of `BufferLoad`.
- Commented out the main testing function in `test_tilelang_language_reshape.py` and replaced it with a direct call to `run_reshape_smem` for streamlined testing.
- Removed unnecessary NVCC compiler flags in `libgen.py` to reduce verbosity.

* Update test_tilelang_language_reshape.py
parent 68989d80
@@ -35,10 +35,11 @@ from .kernel import (
)
from .warpgroup import ws # noqa: F401
from .allocate import (
alloc_var, # noqa: F401
alloc_local, # noqa: F401
alloc_shared, # noqa: F401
alloc_fragment, # noqa: F401
alloc_var, # noqa: F401
alloc_barrier, # noqa: F401
)
from .copy import copy, c2d_im2col # noqa: F401
from .gemm import GemmWarpPolicy, gemm # noqa: F401
@@ -74,3 +74,15 @@ def alloc_var(dtype, scope="local.var"):
T.Buffer: A TVM buffer object allocated as a single-element variable
"""
return T.alloc_buffer([1], dtype, scope=scope)
def alloc_barrier(arrive_count: int):
"""Allocate a barrier buffer.
Args:
arrive_count (int): The number of threads that need to arrive at the barrier
Returns:
T.Buffer: A TVM buffer object allocated as a barrier
"""
return T.alloc_buffer([arrive_count], "uint64", scope="shared.barrier")
@@ -187,12 +187,14 @@ def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union
Returns:
tir.Call: A handle to the barrier wait operation
"""
if isinstance(mbarrier, tir.Call):
if isinstance(mbarrier, (tir.Call, tir.BufferLoad)):
mbarrier = mbarrier
elif isinstance(mbarrier, (tir.PrimExpr, int)):
mbarrier = get_mbarrier(mbarrier)
elif isinstance(mbarrier, tir.Buffer):
mbarrier = tir.BufferLoad(mbarrier, [0])
else:
raise TypeError("mbarrier must be an integer or a tir.Call")
raise TypeError(f"mbarrier must be an integer or a tir.Call, but got {type(mbarrier)}")
return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity)
@@ -203,12 +205,14 @@ def mbarrier_arrive(mbarrier: Union[int, PrimExpr, tir.Call]):
mbarrier: Optional[int, PrimExpr]
The memory barrier to arrive at
"""
if isinstance(mbarrier, tir.Call):
if isinstance(mbarrier, (tir.Call, tir.BufferLoad)):
mbarrier = mbarrier
elif isinstance(mbarrier, (tir.PrimExpr, int)):
mbarrier = get_mbarrier(mbarrier)
elif isinstance(mbarrier, tir.Buffer):
mbarrier = tir.BufferLoad(mbarrier, [0])
else:
raise TypeError("mbarrier must be an integer or a tir.Call")
raise TypeError(f"mbarrier must be an integer or a tir.Call, but got {type(mbarrier)}")
return ptx_arrive_barrier(mbarrier)
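
A note on what the widened `isinstance` checks above enable: the barrier buffer returned by `alloc_barrier` can now be passed to these builtins directly and is wrapped as a `BufferLoad` of element 0 internally. A hedged sketch of the accepted argument forms (names are illustrative, and the snippet assumes it sits inside a kernel body):

```python
# Hedged sketch of the argument forms accepted after this change:
bar = T.alloc_barrier(arrive_count=128)  # tir.Buffer in the "shared.barrier" scope
T.mbarrier_arrive(bar)        # Buffer -> wrapped as tir.BufferLoad(bar, [0])
T.mbarrier_arrive(bar[0])     # BufferLoad -> passed through unchanged
T.mbarrier_arrive(0)          # int / PrimExpr -> resolved via get_mbarrier(0)
T.mbarrier_wait_parity(bar, 0)  # same dispatch, plus a parity phase
```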
@@ -224,16 +228,17 @@ def mbarrier_expect_tx(*args):
return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_expect_tx"), *args)
def wait_wgmma(*args):
def wait_wgmma(id: int):
"""Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
Args:
*args: Variable arguments specifying which operations to wait for
id: int
The id of the WGMMA operation to wait for
Returns:
tir.Call: A handle to the WGMMA wait operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), id)
def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int, Var, None] = None):
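The `wait_wgmma` signature above changes from variadic `*args` to an explicit `id`. A hedged one-step sketch of the intended call pattern (the GEMM operands are assumptions):

```python
# Hedged sketch: issue a warp-group GEMM, then wait on it by id.
T.gemm(A_shared, B_shared, C_local)  # illustrative operands
T.wait_wgmma(0)                      # explicit id, per the updated signature above
```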
@@ -84,6 +84,7 @@ def copy(
src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion],
dst: Union[tir.Buffer, tir.BufferLoad],
coalesced_width: Optional[int] = None,
disable_tma: bool = False,
):
"""Copy data between memory regions.
@@ -130,10 +131,11 @@ def copy(
src = _to_region(src, "r")
dst = _to_region(dst, "w")
if coalesced_width is not None:
return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width)
else:
return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst)
if coalesced_width is None:
coalesced_width = -1 # PrimExpr can not be None
return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width,
disable_tma)
def c2d_im2col(
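At the DSL level, the extra operands threaded into the `tl.copy` intrinsic above surface as keyword arguments of `T.copy`. A hedged sketch (buffer names and tile offsets are illustrative, and the lines assume a kernel body):

```python
# Hedged sketch of the user-facing copy variants after this change:
T.copy(A[bx * 128, 0], A_shared)                     # default path; TMA/bulk copy may be used
T.copy(A[bx * 128, 0], A_shared, disable_tma=True)   # force the non-TMA lowering
T.copy(A[bx * 128, 0], A_shared, coalesced_width=8)  # unchanged; None is now lowered as -1
```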
@@ -84,7 +84,7 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
Returns:
Buffer: A new buffer view with the specified shape
"""
return T.Buffer(shape, src.dtype, src.data)
return T.Tensor(shape, src.dtype, src.data)
def view(src: Buffer,
@@ -104,4 +104,4 @@ def view(src: Buffer,
shape = src.shape
if dtype is None:
dtype = src.dtype
return T.Buffer(shape, dtype, src.data)
return T.Tensor(shape, dtype, src.data)
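
For context on the `T.Buffer` to `T.Tensor` switch above: `reshape` and `view` still return zero-copy views over `src.data`; only the wrapper type changes. A hedged usage sketch, assuming `shape` and `dtype` keyword parameters as suggested by the function body, with illustrative shapes inside a kernel body:

```python
# Hedged sketch: both calls alias the storage of A_shared, no data is moved.
A_shared = T.alloc_shared((128, 64), "float16")
A_flat = T.reshape(A_shared, [128 * 64])                         # same storage, 1-D view
A_bytes = T.view(A_shared, shape=[128 * 64 * 2], dtype="uint8")  # reinterpreted element type
```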
@@ -133,7 +133,7 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
buffer[coords])
def print(obj: Any, msg: str = "") -> tir.PrimExpr:
def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
"""
A generic print function that handles both TIR buffers and primitive expressions.
@@ -143,6 +143,9 @@ def print(obj: Any, msg: str = "") -> tir.PrimExpr:
Parameters:
obj (Any): The object to print. It can be either a tir.Buffer or tir.PrimExpr.
msg (str): An optional message to include in the print statement.
warp_group_id (int): The warp group id to print.
warp_id (int): The warp id to print.
The printing thread will be warp_group_id * warp_group_size + warp_id * warp_size.
Returns:
tir.PrimExpr: The TIR expression for the debug print operation.
@@ -154,6 +157,9 @@ def print(obj: Any, msg: str = "") -> tir.PrimExpr:
# Buffers must be printed in just one thread to avoid duplicate outputs.
# Retrieve the thread bindings for thread x, y, and z.
tx, ty, tz = get_thread_bindings()
warp_group_size = 128
warp_size = 32
main_lane = warp_group_id * warp_group_size + warp_id * warp_size
# Flatten the buffer for consistent printing. This assumes a 1D flattened buffer.
buffer = obj
@@ -173,7 +179,7 @@ def print(obj: Any, msg: str = "") -> tir.PrimExpr:
elems *= dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition = (tx == 0 and ty == 0 and tz == 0)
condition = (tx == main_lane and ty == 0 and tz == 0)
if not msg:
msg = f"buffer<{buffer.name}, {buffer.dtype}>"
return print_fragment_buffer_with_condition(condition, buffer, elems, msg)
@@ -184,7 +190,7 @@ def print(obj: Any, msg: str = "") -> tir.PrimExpr:
elems *= dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition = (tx == 0 and ty == 0 and tz == 0)
condition = (tx == main_lane and ty == 0 and tz == 0)
if not msg:
msg = f"buffer<{buffer.name}, {buffer.dtype}>"
return print_shared_buffer_with_condition(condition, buffer, elems, msg)
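A hedged sketch of the new `print` arguments: with the defaults (0, 0) the behavior matches the old thread-0 print, while non-zero values select a specific warp's lane 0 per the `main_lane` computation above. The buffer name and ids are illustrative, and the lines assume a kernel body:

```python
# Hedged sketch: dump a fragment from warp 2 of warp group 1 instead of thread 0.
C_local = T.alloc_fragment((64, 64), "float32")
T.print(C_local, msg="C tile", warp_group_id=1, warp_id=2)
# selected thread: tx == 1 * 128 + 2 * 32 == 192
```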
@@ -333,7 +333,7 @@ def EliminateStorageSyncForMBarrier():
return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore
def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False):
def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_bytes: int = 16):
"""MergeSharedMemoryAllocations
Returns
@@ -341,7 +341,8 @@ def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False):
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge) # type: ignore
return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge,
align_bytes) # type: ignore
def LowerL2Persistent():
@@ -368,3 +369,9 @@ def AlignDynamicSharedMemoryAllocations(align_bytes: int = 16):
-------
"""
return _ffi_api.AlignDynamicSharedMemoryAllocations(align_bytes) # type: ignore
def LowerSharedBarrier():
"""LowerSharedBarrier
"""
return _ffi_api.LowerSharedBarrier() # type: ignore
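
To show where the updated and new passes above could sit, a hedged sketch of applying them to an `IRModule` follows; the ordering and the surrounding driver code are assumptions, not taken from this PR:

```python
import tvm
from tilelang.transform import LowerSharedBarrier, MergeSharedMemoryAllocations


def run_barrier_passes(mod: tvm.IRModule) -> tvm.IRModule:
    # Hedged sketch: lower shared-scope barriers, then merge shared memory
    # allocations with the 16-byte alignment default shown above.
    seq = tvm.transform.Sequential([
        LowerSharedBarrier(),
        MergeSharedMemoryAllocations(enable_aggressive_merge=False, align_bytes=16),
    ])
    return seq(mod)
```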