Commit 8eab7755, authored by Lei Wang, committed by GitHub

[Reducer] Introduce `alloc_reducer` to separate inter and intra warp reduction (#757)



* [Enhancement] Introduce finalize_reducer operator and layout reducer support

- Added `FinalizeReducer` operator to handle reduction finalization in the TileLang framework, allowing for efficient reduction operations.
- Implemented layout inference for local.reducer buffers, enhancing the handling of layout mappings and reducing complexity in buffer management.
- Updated `setup.py` to include logging for build directory paths, improving build process visibility.
- Enhanced atomic operations with new functions for atomic max, min, load, and store, providing more robust atomicity control in memory operations.
- Refactored parallel loop handling to incorporate reducer information, ensuring proper management of reduction operations in parallel contexts.
- Cleaned up test cases by removing unnecessary cache disabling and optimizing test parameters for better performance.

* Refactor code formatting and improve readability in multiple files

- Cleaned up whitespace in `setup.py` to enhance logging clarity.
- Reformatted `AtomicMax` and `AtomicMin` functions in `common.h` for better alignment and readability.
- Adjusted `debug_print_var` function in `debug.h` to improve code structure and maintainability.
- Enhanced readability of the `atomic_add` function in `customize.py` by breaking long lines for better clarity.

* Remove debug print statements from `copy.cc` and `inject_tma_barrier.cc` to enhance code clarity and maintainability.

* [Enhancement] Disable reuse of small arrays in shared memory allocation

- Added logic to prevent the reuse of small arrays (<= 32 bits) in `merge_shared_memory_allocations.cc`, ensuring they are lowered to registers in LLVM for improved performance and memory management.

* Refactor `setup.py` to remove duplicate logging statements and enhance clarity. Update `finalize_reducer` function documentation in `reduce.py` to include detailed parameter and return descriptions, improving code readability and maintainability.

* Refactor `finalize_reducer` and `reduce` functions to remove redundant target checks. Simplified conditionals by retaining only the `TargetIsHopper` check, enhancing code clarity and maintainability.

* bug fix

* Add thread checks workaround for replicated cases

* Remove the is_one check

* fix lint error

* lint fix

* Update autotune tests to use smaller matrix sizes for improved performance and reliability

* [Refactor] Update FinalizeReducer to FinalizeReducerOp and adjust related methods

- Refactored FinalizeReducer class to FinalizeReducerOp, updating constructor and method signatures for consistency with the new TileOperator structure.
- Enhanced layout inference and cloning methods in FinalizeReducerOpNode.
- Updated test_example_flash_attention.py to call test_example_gqa_bwd instead of tilelang.testing.main.
- Adjusted header inclusions for improved organization and clarity across multiple files.

* [Refactor] Update atomic operations in common.h and modify test_example_flash_attention.py

- Enhanced atomic operations (Add, Min, Max) in common.h to handle half and bfloat16 types more efficiently.
- Updated test_example_flash_attention.py to call test_example_gqa_bwd instead of tilelang.testing.main, improving test organization.

* [Refactor] Simplify CopyNode::LowerBulkCopy logic and update test execution

- Removed redundant checks for contiguous memory access in CopyNode::LowerBulkCopy, streamlining the logic for TMA copy operations.
- Updated test_tilelang_kernel_gemm.py to comment out the main testing function and call a specific test for i8i8i32 tensor operations instead, improving test focus.

---------
Co-authored-by: Huanqi Cao <caohuanqi@deepseek.com>
Co-authored-by: Freebase6912 <amid-gauze-racing@duck.com>
parent b38bd69e
......@@ -54,9 +54,12 @@ from .reduce import (
reduce_abssum, # noqa: F401
reduce_absmax, # noqa: F401
cumsum, # noqa: F401
finalize_reducer, # noqa: F401
)
from .print import print # noqa: F401
from .customize import (
atomic_max, # noqa: F401
atomic_min, # noqa: F401
atomic_add, # noqa: F401
atomic_addx2, # noqa: F401
atomic_addx4, # noqa: F401
......@@ -64,6 +67,8 @@ from .customize import (
clamp, # noqa: F401
reshape, # noqa: F401
view, # noqa: F401
atomic_load, # noqa: F401
atomic_store, # noqa: F401
)
from .logical import any_of, all_of # noqa: F401
from .builtin import * # noqa: F401
......
......@@ -14,6 +14,7 @@ Each function takes shape and dtype parameters and returns a TVM buffer object
with the appropriate memory scope.
"""
from tilelang import tvm as tvm
from tvm.script import tir as T
......
......@@ -7,6 +7,15 @@ from tvm import ir
from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, Var, op
from typing import List, Union
_MEMORY_ORDER_ID_MAP = {
    "relaxed": 0,
    "consume": 1,
    "acquire": 2,
    "release": 3,
    "acq_rel": 4,
    "seq_cst": 5,
}
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
"""Create a memory region descriptor for tile operations.
......@@ -83,7 +92,41 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
def atomic_add(dst: Buffer, value: PrimExpr) -> PrimExpr:
def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
    """Perform an atomic maximum operation.

    Args:
        dst (Buffer): Destination buffer where the atomic maximum will be performed
        value (PrimExpr): Value to compare against the destination; the larger of the two is stored
        memory_order (str | None, optional): Key of _MEMORY_ORDER_ID_MAP. If None, no ordering argument is passed to the extern call.

    Returns:
        PrimExpr: Handle to the atomic maximum operation
    """
    if memory_order is None:
        return T.call_extern("handle", "AtomicMax", T.address_of(dst), value)
    else:
        return T.call_extern("handle", "AtomicMax", T.address_of(dst), value,
                             _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
    """Perform an atomic minimum operation.

    Args:
        dst (Buffer): Destination buffer where the atomic minimum will be performed
        value (PrimExpr): Value to compare against the destination; the smaller of the two is stored
        memory_order (str | None, optional): Key of _MEMORY_ORDER_ID_MAP. If None, no ordering argument is passed to the extern call.

    Returns:
        PrimExpr: Handle to the atomic minimum operation
    """
    if memory_order is None:
        return T.call_extern("handle", "AtomicMin", T.address_of(dst), value)
    else:
        return T.call_extern("handle", "AtomicMin", T.address_of(dst), value,
                             _MEMORY_ORDER_ID_MAP[memory_order])
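The optional `memory_order` argument changes the arity of the extern call: without it, `AtomicMax`/`AtomicMin` receive only the destination address and the value; with it, the integer ID from `_MEMORY_ORDER_ID_MAP` is appended as a third operand. The following is a minimal plain-Python sketch of that argument construction, with no TVM dependency. `build_atomic_args` is a hypothetical helper that is not part of the commit, and the string `"&dst"` merely stands in for `T.address_of(dst)`.

```python
from typing import Optional, Tuple

# Copied from the diff above. The values match the declaration order of the
# C++ std::memory_order enumerators (relaxed, consume, acquire, release,
# acq_rel, seq_cst).
MEMORY_ORDER_ID_MAP = {
    "relaxed": 0,
    "consume": 1,
    "acquire": 2,
    "release": 3,
    "acq_rel": 4,
    "seq_cst": 5,
}


def build_atomic_args(func_name: str,
                      dst_addr: str,
                      value: int,
                      memory_order: Optional[str] = None) -> Tuple:
    """Hypothetical helper: mimic the operand list passed to T.call_extern."""
    args = (func_name, dst_addr, value)
    if memory_order is None:
        # No ordering operand: the device-side default is used.
        return args
    # Append the integer ID so the device helper can select an ordering.
    return args + (MEMORY_ORDER_ID_MAP[memory_order],)


default_call = build_atomic_args("AtomicMax", "&dst", 7)
ordered_call = build_atomic_args("AtomicMax", "&dst", 7, "release")
```

Here `default_call` has three entries and `ordered_call` has four, which suggests the device-side helpers are expected to accept both forms.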
def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
"""Perform an atomic addition operation.
Args:
......@@ -93,10 +136,6 @@ def atomic_add(dst: Buffer, value: PrimExpr) -> PrimExpr:
Returns:
PrimExpr: Handle to the atomic addition operation
"""
if isinstance(dst, BufferLoad) and isinstance(value, BufferLoad):
return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value)
if isinstance(dst, Buffer) and isinstance(value, Buffer):
ir.assert_structural_equal(dst.shape, value.shape)
def get_extent(data):
if isinstance(data, Var) and T.has_let_value(data):
......@@ -110,6 +149,17 @@ def atomic_add(dst: Buffer, value: PrimExpr) -> PrimExpr:
src_extent = get_extent(value)
dst_extent = get_extent(dst)
if dst_extent is None and src_extent is None:
if memory_order is None:
return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value)
else:
return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value,
_MEMORY_ORDER_ID_MAP[memory_order])
if isinstance(dst, Buffer) and isinstance(value, Buffer):
ir.assert_structural_equal(dst.shape, value.shape)
assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
......@@ -217,3 +267,32 @@ def view(src: Buffer,
if dtype is None:
dtype = src.dtype
return T.Tensor(shape, dtype, src.data)
def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr:
    """Load a value from the input buffer with the specified memory order.

    Args:
        src (Buffer): Input buffer to load from
        memory_order (str, optional): Memory ordering for the load operation. Defaults to "seq_cst".

    Returns:
        PrimExpr: The loaded value from the buffer
    """
    return T.call_extern(src.dtype, "AtomicLoad", T.address_of(src),
                         _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr:
    """Store a value to the destination buffer with the specified memory order.

    Args:
        dst (Buffer): Destination buffer to store to
        src (PrimExpr): Value to store
        memory_order (str, optional): Memory ordering for the store operation. Defaults to "seq_cst".

    Returns:
        PrimExpr: The handle of the store operation
    """
    return T.call_extern("handle", "AtomicStore", T.address_of(dst), src,
                         _MEMORY_ORDER_ID_MAP[memory_order])
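As written, `atomic_load` and `atomic_store` look the order up directly, so an unknown name surfaces as a bare `KeyError`. A caller may prefer to validate first. The sketch below is one way to do that in plain Python. `check_memory_order` is a hypothetical helper, and the per-operation restrictions are an assumption: they hold for C++ `std::atomic`, but the commit does not show whether the device-side `AtomicLoad`/`AtomicStore` helpers enforce the same rules.

```python
# Names accepted by the diff's _MEMORY_ORDER_ID_MAP; the tuple index equals the ID.
MEMORY_ORDERS = ("relaxed", "consume", "acquire", "release", "acq_rel", "seq_cst")

# Assumption: the device helpers follow C++ std::atomic rules, under which
# loads may not use release orderings and stores may not use acquire orderings.
INVALID_FOR = {
    "load": {"release", "acq_rel"},
    "store": {"consume", "acquire", "acq_rel"},
}


def check_memory_order(op: str, memory_order: str) -> int:
    """Hypothetical helper: return the integer ID or raise ValueError."""
    if memory_order not in MEMORY_ORDERS:
        raise ValueError(
            f"unknown memory_order {memory_order!r}; expected one of {MEMORY_ORDERS}")
    if memory_order in INVALID_FOR.get(op, set()):
        raise ValueError(
            f"memory_order {memory_order!r} is not valid for an atomic {op}")
    return MEMORY_ORDERS.index(memory_order)


load_id = check_memory_order("load", "acquire")
store_id = check_memory_order("store", "release")
```

An acquire load paired with a release store is the usual way to publish data from one thread to another, which is why those two combinations are used in the example.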
......@@ -185,3 +185,19 @@ def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reve
dim,
reverse,
)
def finalize_reducer(reducer: tir.Buffer):
    """Finalize the reducer buffer.

    Args:
        reducer (tir.Buffer): The reducer buffer

    Returns:
        tir.Call: Handle to the finalize reducer operation
    """
    return tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.finalize_reducer"),
        reducer.access_ptr("w"),
    )
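The commit title describes separating per-thread accumulation from the cross-thread combination that `finalize_reducer` triggers. The following plain-Python model illustrates that two-phase idea only. It is a conceptual sketch and not TileLang's lowering: the thread count, the strided work assignment, and the `finalize` function are all assumptions made for illustration.

```python
from typing import Callable, List


def finalize(partials: List[float],
             combine: Callable[[float, float], float]) -> float:
    """Hypothetical model: merge per-thread partial results into one value."""
    total = partials[0]
    for partial in partials[1:]:
        total = combine(total, partial)
    return total


# Phase 1: each of 4 simulated threads accumulates into its own private slot,
# so no synchronization is needed while the loop runs.
data = list(range(16))
num_threads = 4
partials = [0.0] * num_threads
for i, x in enumerate(data):
    partials[i % num_threads] += x

# Phase 2: a single finalize step combines the private partials.
result = finalize(partials, lambda a, b: a + b)
```

Deferring the combination to one explicit step means the accumulation loop touches only private state, which is the usual motivation for this kind of split.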
......@@ -419,3 +419,9 @@ def LowerDeviceKernelLaunch():
"""LowerDeviceKernelLaunch
"""
return _ffi_api.LowerDeviceKernelLaunch() # type: ignore
def LayoutReducer():
    """LayoutReducer
    """
    return _ffi_api.LayoutReducer()  # type: ignore