Unverified commit 8eab7755, authored by Lei Wang, committed by GitHub

[Reducer] Introduce `alloc_reducer` to separate inter and intra warp reduction (#757)



* [Enhancement] Introduce finalize_reducer operator and layout reducer support

- Added `FinalizeReducer` operator to handle reduction finalization in the TileLang framework, allowing for efficient reduction operations.
- Implemented layout inference for local.reducer buffers, enhancing the handling of layout mappings and reducing complexity in buffer management.
- Updated `setup.py` to include logging for build directory paths, improving build process visibility.
- Enhanced atomic operations with new functions for atomic max, min, load, and store, providing more robust atomicity control in memory operations.
- Refactored parallel loop handling to incorporate reducer information, ensuring proper management of reduction operations in parallel contexts.
- Cleaned up test cases by removing unnecessary cache disabling and optimizing test parameters for better performance.

* Refactor code formatting and improve readability in multiple files

- Cleaned up whitespace in `setup.py` to enhance logging clarity.
- Reformatted `AtomicMax` and `AtomicMin` functions in `common.h` for better alignment and readability.
- Adjusted `debug_print_var` function in `debug.h` to improve code structure and maintainability.
- Enhanced readability of the `atomic_add` function in `customize.py` by breaking long lines for better clarity.

* Remove debug print statements from `copy.cc` and `inject_tma_barrier.cc` to enhance code clarity and maintainability.

* [Enhancement] Disable reuse of small arrays in shared memory allocation

- Added logic to prevent the reuse of small arrays (<= 32 bits) in `merge_shared_memory_allocations.cc`, ensuring they are lowered to registers in LLVM for improved performance and memory management.

* Refactor `setup.py` to remove duplicate logging statements and enhance clarity. Update `finalize_reducer` function documentation in `reduce.py` to include detailed parameter and return descriptions, improving code readability and maintainability.

* Refactor `finalize_reducer` and `reduce` functions to remove redundant target checks. Simplified conditionals by retaining only the `TargetIsHopper` check, enhancing code clarity and maintainability.

* bug fix

* Add thread checks workaround for replicated cases

* Remove the is_one check

* fix lint error

* lint fix

* Update autotune tests to use smaller matrix sizes for improved performance and reliability

* [Refactor] Update FinalizeReducer to FinalizeReducerOp and adjust related methods

- Refactored FinalizeReducer class to FinalizeReducerOp, updating constructor and method signatures for consistency with the new TileOperator structure.
- Enhanced layout inference and cloning methods in FinalizeReducerOpNode.
- Updated test_example_flash_attention.py to call test_example_gqa_bwd instead of tilelang.testing.main.
- Adjusted header inclusions for improved organization and clarity across multiple files.

* [Refactor] Update atomic operations in common.h and modify test_example_flash_attention.py

- Enhanced atomic operations (Add, Min, Max) in common.h to handle half and bfloat16 types more efficiently.
- Updated test_example_flash_attention.py to call test_example_gqa_bwd instead of tilelang.testing.main, improving test organization.

* [Refactor] Simplify CopyNode::LowerBulkCopy logic and update test execution

- Removed redundant checks for contiguous memory access in CopyNode::LowerBulkCopy, streamlining the logic for TMA copy operations.
- Updated test_tilelang_kernel_gemm.py to comment out the main testing function and call a specific test for i8i8i32 tensor operations instead, improving test focus.

---------
Co-authored-by: Huanqi Cao <caohuanqi@deepseek.com>
Co-authored-by: Freebase6912 <amid-gauze-racing@duck.com>
parent b38bd69e
@@ -54,9 +54,12 @@ from .reduce import (
     reduce_abssum,  # noqa: F401
     reduce_absmax,  # noqa: F401
     cumsum,  # noqa: F401
+    finalize_reducer,  # noqa: F401
 )
 from .print import print  # noqa: F401
 from .customize import (
+    atomic_max,  # noqa: F401
+    atomic_min,  # noqa: F401
     atomic_add,  # noqa: F401
     atomic_addx2,  # noqa: F401
     atomic_addx4,  # noqa: F401
@@ -64,6 +67,8 @@ from .customize import (
     clamp,  # noqa: F401
     reshape,  # noqa: F401
     view,  # noqa: F401
+    atomic_load,  # noqa: F401
+    atomic_store,  # noqa: F401
 )
 from .logical import any_of, all_of  # noqa: F401
 from .builtin import *  # noqa: F401
......
@@ -14,6 +14,7 @@ Each function takes shape and dtype parameters and returns a TVM buffer object
 with the appropriate memory scope.
 """
+from tilelang import tvm as tvm
 from tvm.script import tir as T
......
@@ -7,6 +7,15 @@ from tvm import ir
 from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, Var, op
 from typing import List, Union
+_MEMORY_ORDER_ID_MAP = {
+    "relaxed": 0,
+    "consume": 1,
+    "acquire": 2,
+    "release": 3,
+    "acq_rel": 4,
+    "seq_cst": 5,
+}
 def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
     """Create a memory region descriptor for tile operations.
@@ -83,7 +92,41 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
     return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
-def atomic_add(dst: Buffer, value: PrimExpr) -> PrimExpr:
+def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
+    """Perform an atomic maximum operation.
+    Args:
+        dst (Buffer): Destination buffer where the atomic maximum will be performed
+        value (PrimExpr): Value to be atomically compared against the destination
+    Returns:
+        PrimExpr: Handle to the atomic maximum operation
+    """
+    if memory_order is None:
+        return T.call_extern("handle", "AtomicMax", T.address_of(dst), value)
+    else:
+        return T.call_extern("handle", "AtomicMax", T.address_of(dst), value,
+                             _MEMORY_ORDER_ID_MAP[memory_order])
+def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
+    """Perform an atomic minimum operation.
+    Args:
+        dst (Buffer): Destination buffer where the atomic minimum will be performed
+        value (PrimExpr): Value to be atomically compared against the destination
+    Returns:
+        PrimExpr: Handle to the atomic minimum operation
+    """
+    if memory_order is None:
+        return T.call_extern("handle", "AtomicMin", T.address_of(dst), value)
+    else:
+        return T.call_extern("handle", "AtomicMin", T.address_of(dst), value,
+                             _MEMORY_ORDER_ID_MAP[memory_order])
+def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
     """Perform an atomic addition operation.
     Args:
@@ -93,10 +136,6 @@ def atomic_add(dst: Buffer, value: PrimExpr) -> PrimExpr:
     Returns:
         PrimExpr: Handle to the atomic addition operation
     """
-    if isinstance(dst, BufferLoad) and isinstance(value, BufferLoad):
-        return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value)
-    if isinstance(dst, Buffer) and isinstance(value, Buffer):
-        ir.assert_structural_equal(dst.shape, value.shape)
     def get_extent(data):
         if isinstance(data, Var) and T.has_let_value(data):
@@ -110,6 +149,17 @@ def atomic_add(dst: Buffer, value: PrimExpr) -> PrimExpr:
     src_extent = get_extent(value)
     dst_extent = get_extent(dst)
+    if dst_extent is None and src_extent is None:
+        if memory_order is None:
+            return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value)
+        else:
+            return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value,
+                                 _MEMORY_ORDER_ID_MAP[memory_order])
+    if isinstance(dst, Buffer) and isinstance(value, Buffer):
+        ir.assert_structural_equal(dst.shape, value.shape)
     assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
     src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
     dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
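The extent handling in this hunk can be tried out without TileLang. The following is a minimal sketch on plain Python lists, assuming `None` stands for "no extent could be deduced"; `resolve_extents` is a hypothetical name used only for illustration and is not part of the commit.

```python
from typing import List, Optional, Sequence, Tuple


def resolve_extents(
    src_extent: Optional[Sequence[int]],
    dst_extent: Optional[Sequence[int]],
) -> Optional[Tuple[List[int], List[int]]]:
    """Hypothetical stand-in for the extent logic in atomic_add.

    Returns None when neither side has an extent, which corresponds to the
    element-wise path that emits a single AtomicAdd extern call.
    """
    if dst_extent is None and src_extent is None:
        return None
    # Same guard as in the diff: at least one side must carry an extent.
    assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
    # The side without an extent is padded with 1s to the other side's rank.
    src = list(src_extent) if src_extent else [1] * len(dst_extent)
    dst = list(dst_extent) if dst_extent else [1] * len(src_extent)
    return src, dst


if __name__ == "__main__":
    print(resolve_extents(None, None))
    print(resolve_extents(None, [4, 8]))
    print(resolve_extents([2, 3], None))
```

The point of the reordering in the commit appears to be that the scalar case is now decided by the deduced extents rather than by an `isinstance` check on `BufferLoad`.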
@@ -217,3 +267,32 @@ def view(src: Buffer,
     if dtype is None:
         dtype = src.dtype
     return T.Tensor(shape, dtype, src.data)
+def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr:
+    """Loads a value from the input buffer with specified memory_order.
+    Args:
+        src (Buffer): Input buffer to load from
+        memory_order (str, optional): Memory ordering for the load operation. Defaults to "seq_cst".
+    Returns:
+        PrimExpr: The loaded value from the buffer
+    """
+    return T.call_extern(src.dtype, "AtomicLoad", T.address_of(src),
+                         _MEMORY_ORDER_ID_MAP[memory_order])
+def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr:
+    """Stores a value to the destination buffer with specified memory_order.
+    Args:
+        dst (Buffer): Destination buffer to store to
+        src (PrimExpr): Value to store
+        memory_order (str, optional): Memory ordering for the store operation. Defaults to "seq_cst".
+    Returns:
+        PrimExpr: The handle of the store operation
+    """
+    return T.call_extern("handle", "AtomicStore", T.address_of(dst), src,
+                         _MEMORY_ORDER_ID_MAP[memory_order])
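The atomic wrappers in this file translate the `memory_order` string into an integer through `_MEMORY_ORDER_ID_MAP`, so a misspelled name surfaces as a bare `KeyError`. A minimal sketch of a stricter lookup follows. `memory_order_id` and `extern_tail_args` are hypothetical helper names, not part of the commit, and the mapping is copied from the diff so the snippet runs without TileLang.

```python
from typing import Optional, Tuple

# Copied from the diff. The IDs follow the declaration order of the
# C++ std::memory_order enumerators (relaxed first, seq_cst last).
_MEMORY_ORDER_ID_MAP = {
    "relaxed": 0,
    "consume": 1,
    "acquire": 2,
    "release": 3,
    "acq_rel": 4,
    "seq_cst": 5,
}


def memory_order_id(name: str) -> int:
    """Hypothetical helper: map a memory-order name to its integer ID."""
    try:
        return _MEMORY_ORDER_ID_MAP[name]
    except KeyError:
        allowed = ", ".join(_MEMORY_ORDER_ID_MAP)
        raise ValueError(
            f"unknown memory_order {name!r}; expected one of: {allowed}") from None


def extern_tail_args(memory_order: Optional[str]) -> Tuple[int, ...]:
    """Hypothetical helper: trailing arguments for the extern call.

    Mirrors atomic_max, atomic_min and atomic_add in the diff, which append
    the ID only when a memory order was given.
    """
    return () if memory_order is None else (memory_order_id(memory_order),)


if __name__ == "__main__":
    print(extern_tail_args(None))
    print(extern_tail_args("acq_rel"))
```

Note that `atomic_load` and `atomic_store` default to `"seq_cst"`, so they always pass an ID, whereas the read-modify-write wrappers omit it when `memory_order` is `None`.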
@@ -185,3 +185,19 @@ def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reve
         dim,
         reverse,
     )
+def finalize_reducer(reducer: tir.Buffer):
+    """Finalize the reducer buffer.
+    Args:
+        reducer (tir.Buffer): The reducer buffer
+    Returns:
+        tir.Call: Handle to the finalize reducer operation
+    """
+    return tir.call_intrin(
+        "handle",
+        tir.op.Op.get("tl.finalize_reducer"),
+        reducer.access_ptr("w"),
+    )
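The commit title describes the goal as separating the two stages of a reduction, with `finalize_reducer` as the explicit second stage. The diff does not show how the operator is lowered, so the following is only a conceptual sketch in plain Python, assuming a reducer holds one private partial result per thread until it is finalized. `simulate_reducer` is a hypothetical name, and the model ignores warps, layouts and replication.

```python
import operator
from functools import reduce
from typing import Callable, List, Sequence, Tuple


def simulate_reducer(
    values: Sequence[int],
    num_threads: int,
    combine: Callable[[int, int], int] = operator.add,
    identity: int = 0,
) -> Tuple[List[int], int]:
    """Conceptual two-stage reduction, not TileLang's implementation."""
    # Stage 1: each simulated thread folds its strided share of the input
    # into a private partial, with no communication between threads.
    partials = [identity] * num_threads
    for i, v in enumerate(values):
        t = i % num_threads
        partials[t] = combine(partials[t], v)
    # Stage 2 ("finalize"): the per-thread partials are combined once.
    total = reduce(combine, partials, identity)
    return partials, total


if __name__ == "__main__":
    partials, total = simulate_reducer(list(range(10)), num_threads=4)
    print(partials, total)
```

Deferring the second stage means the cross-thread combination happens once per reducer rather than once per update, which is presumably the motivation for exposing it as a separate call.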
@@ -419,3 +419,9 @@ def LowerDeviceKernelLaunch():
     """LowerDeviceKernelLaunch
     """
     return _ffi_api.LowerDeviceKernelLaunch()  # type: ignore
+def LayoutReducer():
+    """LayoutReducer
+    """
+    return _ffi_api.LayoutReducer()  # type: ignore