Unverified Commit 2c0072a8 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Update buffer handling in copy and atomic operations (#1247)

* [Refactor] Update buffer handling in copy and atomic operations

* Refactored the `copy` and `atomic_add` functions to use element-wise minimum for defining copy extents, ensuring correct handling of overlapping regions.
* Updated utility functions to create `BufferLoad` instances with explicit extents, improving memory management and clarity.
* Removed unused imports from `atomic.py` and `copy.py` to streamline the codebase.
* Adjusted logging in `copy.cc` to provide clearer warnings for fallback scenarios in bulk copy operations.

* Remove obsolete .git_commit.txt file

* Add unit test for dynamic copy extent handling in TileLang

* Introduced a new test file `test_tilelang_issue_1237.py` to verify that the `T.copy` function correctly manages dynamic extents during primitive function building.
* The test reproduces a specific issue related to dynamic slice lengths and static buffer sizes, ensuring robustness in the handling of such scenarios.
* The test does not require execution of the kernel, as building the primitive function is sufficient to validate the fix.

* lint fix

* fix

* Revert "fix"

This reverts commit 828b4c1e4de76a7d11e4d4092927303fbbe00097.

* Update TVM submodule and refactor atomic and copy functions

* Updated the TVM submodule to a dirty state.
* Refactored `atomic_add` and `copy` functions to pass extents explicitly to the `_to_region` helper, improving clarity and correctness in handling buffer regions.
* Commented out the main execution call in the test example for `cast` and added a new function call to better demonstrate the example usage.

* Enhance extent handling in atomic and copy functions

* Introduced `legalize_pairwise_extents` utility to align and broadcast extent lists for `atomic_add` and `copy` functions, ensuring compatibility and correctness in buffer operations.
* Updated both functions to utilize the new utility, improving clarity and robustness in handling dynamic and static extents.
* Added comments to clarify the extent handling logic.

* Enhance `legalize_pairwise_extents` function with early-exit rule

* Added an early-exit condition to the `legalize_pairwise_extents` function to return original extents if the number of non-1 dimensions in both source and destination extents is equal, improving performance by avoiding unnecessary adjustments.
* Updated the function's documentation to clarify the new behavior and maintain clarity in the extent handling logic.

* lint fix
parent d7164abf
...@@ -1504,7 +1504,12 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, ...@@ -1504,7 +1504,12 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
} }
auto inner_box_dim = as_const_int(desc.smem_box[0]); auto inner_box_dim = as_const_int(desc.smem_box[0]);
ICHECK(inner_box_dim != nullptr); if (inner_box_dim == nullptr) {
LOG(WARNING) << "inner_box_dim " << desc.smem_box[0]
<< " can only be a constant integer for TMA bulk copy, "
"fallback to normal copy";
return LowerNormalCopy(T, analyzer);
}
int instruction_dim = *inner_box_dim; int instruction_dim = *inner_box_dim;
if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B)) { if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B)) {
instruction_dim = 64 / src->dtype.bytes(); instruction_dim = 64 / src->dtype.bytes();
......
import tilelang.testing
from tilelang import language as T
def test_issue_1237_dynamic_copy_extent_builds():
# Repro from debug/1113_issues/copy_dyn.py, adapted as a unit test.
# The goal is to ensure T.copy correctly handles dynamic extents
# (e.g., src slice length vs. static dst buffer size) during prim_func building.
length = T.symbolic("len", dtype="int32")
@T.prim_func
def sample_kernel(global_tensor: T.Tensor[(length,), "int32"]): # noqa: F821
with T.Kernel(1, threads=32):
buffer_shared = T.alloc_shared((1024,), dtype="int32")
T.copy(global_tensor[0:length], buffer_shared)
# Building the prim_func is sufficient to exercise the bug path; no need to JIT/execute.
_ = sample_kernel
if __name__ == "__main__":
tilelang.testing.main()
...@@ -6,8 +6,8 @@ from __future__ import annotations ...@@ -6,8 +6,8 @@ from __future__ import annotations
import tilelang.language as T import tilelang.language as T
from tvm import ir, tir from tvm import ir, tir
from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region from tilelang.language.utils import buffer_region_to_tile_region, buffer_load_to_tile_region
from tilelang.utils.language import get_buffer_region_from_load from tilelang.utils.language import get_buffer_region_from_load, legalize_pairwise_extents
_MEMORY_ORDER_ID_MAP = { _MEMORY_ORDER_ID_MAP = {
"relaxed": 0, "relaxed": 0,
...@@ -201,13 +201,14 @@ def atomic_add(dst: Buffer, ...@@ -201,13 +201,14 @@ def atomic_add(dst: Buffer,
assert src_extent or dst_extent, "Can't deduce atomicadd extents from args" assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
extent = max(src_extent, dst_extent) src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent)
def _to_region(data, access_type): def _to_region(data, access_type, extent):
if isinstance(data, tir.Var) and T.has_let_value(data): if isinstance(data, tir.Var) and T.has_let_value(data):
data = T.get_let_value(data) data = T.get_let_value(data)
if isinstance(data, tir.Buffer): if isinstance(data, tir.Buffer):
return buffer_to_tile_region(data, access_type) zeros = [tir.IntImm("int32", 0) for _ in extent]
return buffer_load_to_tile_region(tir.BufferLoad(data, zeros), access_type, extent)
elif isinstance(data, tir.BufferRegion): elif isinstance(data, tir.BufferRegion):
return buffer_region_to_tile_region(data, access_type, extent) return buffer_region_to_tile_region(data, access_type, extent)
elif isinstance(data, tir.BufferLoad): elif isinstance(data, tir.BufferLoad):
...@@ -218,8 +219,8 @@ def atomic_add(dst: Buffer, ...@@ -218,8 +219,8 @@ def atomic_add(dst: Buffer,
else: else:
return buffer_load_to_tile_region(data, access_type, extent) return buffer_load_to_tile_region(data, access_type, extent)
value = _to_region(value, "r") value = _to_region(value, "r", src_extent)
dst = _to_region(dst, "w") dst = _to_region(dst, "w", dst_extent)
# Note: tile-region-based atomic operations don't support return_prev yet # Note: tile-region-based atomic operations don't support return_prev yet
# This would need to be implemented in the tile runtime # This would need to be implemented in the tile runtime
......
...@@ -3,9 +3,12 @@ from __future__ import annotations ...@@ -3,9 +3,12 @@ from __future__ import annotations
from typing import Literal from typing import Literal
from tilelang import language as T from tilelang import language as T
from tilelang.utils.language import get_buffer_region_from_load from tilelang.utils.language import (
get_buffer_region_from_load,
legalize_pairwise_extents,
)
from tvm import ir, tir from tvm import ir, tir
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region from tilelang.language.utils import buffer_region_to_tile_region, buffer_load_to_tile_region
def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
...@@ -55,15 +58,26 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, ...@@ -55,15 +58,26 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
return tir.BufferStore(dst.buffer, src, dst.indices) return tir.BufferStore(dst.buffer, src, dst.indices)
assert src_extent or dst_extent, "Can't deduce copy extents from args" assert src_extent or dst_extent, "Can't deduce copy extents from args"
# Treat missing extent as length-matched ones to enable broadcasting logic.
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
extent = max(src_extent, dst_extent)
def _to_region(data, access_type): # Align and broadcast extents from the right (tail) side independently
# for src and dst, so we can pass them unchanged into _to_region.
# Rules per-dim from the right:
# - equal -> keep both
# - one is 1 -> set that side to the other side's dim
# - otherwise -> error
src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent)
def _to_region(data, access_type, extent):
if isinstance(data, tir.Var) and T.has_let_value(data): if isinstance(data, tir.Var) and T.has_let_value(data):
data = T.get_let_value(data) data = T.get_let_value(data)
if isinstance(data, tir.Buffer): if isinstance(data, tir.Buffer):
return buffer_to_tile_region(data, access_type) # Restrict a raw buffer to the computed copy extent by creating
# a BufferLoad at origin and passing the extents explicitly.
zeros = [tir.IntImm("int32", 0) for _ in extent]
return buffer_load_to_tile_region(tir.BufferLoad(data, zeros), access_type, extent)
elif isinstance(data, tir.BufferRegion): elif isinstance(data, tir.BufferRegion):
return buffer_region_to_tile_region(data, access_type, extent) return buffer_region_to_tile_region(data, access_type, extent)
elif isinstance(data, tir.BufferLoad): elif isinstance(data, tir.BufferLoad):
...@@ -74,8 +88,9 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, ...@@ -74,8 +88,9 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
else: else:
return buffer_load_to_tile_region(data, access_type, extent) return buffer_load_to_tile_region(data, access_type, extent)
src = _to_region(src, "r") # Use legalized extents for src and dst respectively.
dst = _to_region(dst, "w") src = _to_region(src, "r", src_extent)
dst = _to_region(dst, "w", dst_extent)
if coalesced_width is None: if coalesced_width is None:
coalesced_width = -1 # PrimExpr can not be None coalesced_width = -1 # PrimExpr can not be None
......
...@@ -85,7 +85,14 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s ...@@ -85,7 +85,14 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
extents extents
), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}" ), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents) # Clamp extents element-wise so that the produced region respects the
# requested copy/fill extent, supporting dynamic PrimExpr via tir.min.
clamped_extents = [
tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i]
for i in range(len(region_extents))
]
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents)
def index_to_coordinates(index, shape) -> list[PrimExpr]: def index_to_coordinates(index, shape) -> list[PrimExpr]:
......
...@@ -367,6 +367,52 @@ def prim_expr_equal(lhs, rhs) -> bool: ...@@ -367,6 +367,52 @@ def prim_expr_equal(lhs, rhs) -> bool:
return tir.analysis.expr_deep_equal(lhs, rhs) return tir.analysis.expr_deep_equal(lhs, rhs)
def legalize_pairwise_extents(src_extents: list, dst_extents: list) -> tuple[list, list]:
"""
Right-align and broadcast two extent lists to be mutually compatible.
Early-exit rule:
- If the number of non-1 dimensions in `src_extents` equals that in `dst_extents`,
no adjustment is made; the original extents are returned unchanged. This
preserves the per-dimension iteration mapping (one loop var per non-1 dim)
and avoids creating extra varying axes on either side.
Otherwise, for each pair of tail-aligned dimensions (x, y):
- if x == y: keep both
- elif x == 1: set x = y
- elif y == 1: set y = x
- else: promote both to tir.max(x, y) to handle dynamic-vs-static safely
Leading unmatched dimensions are kept as-is.
Returns a tuple of new lists (src_new, dst_new).
"""
a = list(src_extents)
b = list(dst_extents)
# If both sides have the same number of non-1 extents, don't re-broadcast.
def _num_non_one(exts: list) -> int:
return sum(0 if prim_expr_equal(x, 1) else 1 for x in exts)
if _num_non_one(a) == _num_non_one(b):
return a, b
k = min(len(a), len(b))
for i in range(1, k + 1):
x, y = a[-i], b[-i]
if prim_expr_equal(x, y):
continue
elif prim_expr_equal(x, 1):
a[-i] = y
elif prim_expr_equal(y, 1):
b[-i] = x
else:
# Dynamic mismatch: promote to max so downstream clamping/predicates remain safe
m = tir.max(x, y)
a[-i] = m
b[-i] = m
return a, b
def is_full_region(buffer_region: BufferRegion) -> bool: def is_full_region(buffer_region: BufferRegion) -> bool:
""" """
Check whether a BufferRegion covers the full buffer region. Check whether a BufferRegion covers the full buffer region.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment