Unverified commit 869f021b, authored by Dayuxiaoshui and committed by GitHub

[Feature] Support region as input of T.cumsum (#1426)



* [Feature] Support region as input of T.cumsum

- Extend T.cumsum to accept BufferRegion and BufferLoad inputs in addition to Buffer
- This enables operations on buffer slices/regions like:
  T.cumsum(InputG_fragment[i * chunk_size:(i + 1) * chunk_size], dim=0)
- Update cumsum_fragment to handle region inputs properly
- Add comprehensive tests for 1D and 2D region inputs including normal and reverse modes

Fixes #879
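
A minimal sketch of the new call pattern (illustrative names and sizes, mirroring the tests added in this PR; `A`, `B`, and `chunked_cumsum` are hypothetical):

  import tilelang.language as T

  N, chunk = 1024, 128

  @T.prim_func
  def chunked_cumsum(A: T.Tensor((N,), "float32"), B: T.Tensor((N,), "float32")):
      with T.Kernel(T.ceildiv(N, chunk), threads=chunk) as bx:
          start = bx * chunk
          A_shared = T.alloc_shared((chunk,), "float32")
          # Region (slice) as the copy source -- enabled by this change
          T.copy(A[start : start + chunk], A_shared)
          T.cumsum(src=A_shared, dim=0)
          # Region as the copy destination -- also enabled by this change
          T.copy(A_shared, B[start : start + chunk])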

* Fix formatting and add docstring for cumsum_fragment

- Add comprehensive docstring for cumsum_fragment function
- Format code according to ruff style guidelines

* Fix CodeRabbit review issues

- Fix negative dimension bounds check (dim < -len(shape) instead of dim <= -len(shape))
- Add src/dst shape compatibility validation for out-of-place cumsum
- Update copy() type annotation to accept BufferRegion as dst parameter
- Fix test in-place mutation issues by using out-of-place cumsum operations
- Add non-divisible size test cases for tail region coverage

* Fix out-of-bounds access in region tests

- Add bounds clamping using T.min() for chunk_end calculations
- Prevents accessing beyond tensor bounds for non-divisible sizes
- Matches reference implementation behavior
- Fixes both 1D and 2D region test cases

* Fix region test: use simple slice expressions instead of T.min()

- Remove T.min(), which cannot be used directly in slice indices
- Use the chunk_start + chunk_size form instead
- Rely on the system's automatic bounds checking for non-divisible sizes
- Update comments to reflect this approach (see the sketch below)
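
The rejected and accepted forms side by side (a fragment, not a complete kernel; `A`, `A_shared`, `chunk_start`, `chunk_size`, and `N` are as in the tests added in this PR):

  # Rejected: T.min() cannot be used directly in a slice bound
  #   T.copy(A[chunk_start : T.min(chunk_start + chunk_size, N)], A_shared)
  # Accepted: a plain extent; the tail of a non-divisible size is handled by
  # the system's automatic bounds checking
  T.copy(A[chunk_start : chunk_start + chunk_size], A_shared)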

* Fix cumsum region: use region extents in lowering and update tests for shared memory

* Simplify fragment scope check using is_fragment()

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent bcae814e
@@ -528,23 +528,29 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   std::stringstream ss;
   auto threads = T.thread_bounds->extent;
   Array<PrimExpr> args;
-  int ndim = static_cast<int>(src->shape.size());
   // Build access pointers from regions locally
   PrimExpr srcPtr = MakeAccessPtrFromRegion(srcRegion_, 1);
   PrimExpr dstPtr = MakeAccessPtrFromRegion(dstRegion_, 2);
+  // Use region extents instead of buffer shape for correct slice handling
+  Array<PrimExpr> src_extents;
+  for (const auto &range : srcRegion_->region) {
+    src_extents.push_back(range->extent);
+  }
+  int ndim = static_cast<int>(src_extents.size());
   if (ndim == 1) {
     ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim "
                          "= 0.";
     ss << "tl::CumSum1D<" << threads << ", " << (reverse ? "true" : "false")
        << ">::run";
-    args = {StringImm(ss.str()), srcPtr, dstPtr, src->shape[0]};
+    args = {StringImm(ss.str()), srcPtr, dstPtr, src_extents[0]};
   } else if (ndim == 2) {
     ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
        << (reverse ? "true" : "false") << ">::run";
-    args = {StringImm(ss.str()), srcPtr, dstPtr, src->shape[0],
-            src->shape[1]};
+    args = {StringImm(ss.str()), srcPtr, dstPtr, src_extents[0],
+            src_extents[1]};
   } else {
     LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got "
                << ndim << "D.";
...
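
Why the extents matter, as a self-contained PyTorch sketch of the intended semantics (hypothetical sizes, not TileLang code): for a slice of a length-8 buffer, the scan must cover the region extent (4), not the buffer shape (8).

  import torch

  A = torch.arange(1, 9, dtype=torch.float32)  # buffer shape: (8,)
  region = A[2:6]                              # region extent: 4
  # Correct lowering scans only the 4 sliced elements...
  print(region.cumsum(dim=0))                  # tensor([ 3.,  7., 12., 18.])
  # ...using the buffer shape (8) instead would scan past the slice bounds.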
@@ -174,5 +174,139 @@ def test_cumsum_fragment_1d():
     run_cumsum_1d(1024, 128, reverse=True, scope="fragment")
 
 
+def cumsum_region_test_1d(N, chunk_size, reverse=False, dtype="float32"):
+    """Test cumsum with buffer region (slice) as input."""
+    import tilelang.language as T
+
+    @T.prim_func
+    def cumsum_region(
+        InputG_fragment: T.Tensor((N,), dtype),
+        OutputG_fragment: T.Tensor((N,), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, chunk_size), threads=chunk_size) as bx:
+            i = bx
+            chunk_start = i * chunk_size
+            # Copy region to shared memory first (cumsum only supports shared memory)
+            A_shared = T.alloc_shared((chunk_size,), dtype)
+            T.copy(InputG_fragment[chunk_start : chunk_start + chunk_size], A_shared)
+            # Test cumsum with region input - in-place operation on shared memory
+            # This demonstrates the feature: T.cumsum(region, dim=0)
+            T.cumsum(src=A_shared, dim=0, reverse=reverse)
+            # Copy result back to global memory
+            T.copy(A_shared, OutputG_fragment[chunk_start : chunk_start + chunk_size])
+
+    return cumsum_region
+
+
+def run_cumsum_region_1d(N, chunk_size, reverse=False, dtype="float32"):
+    """Run test for cumsum with region input."""
+    program = cumsum_region_test_1d(N, chunk_size, reverse, dtype)
+    jit_kernel = tl.compile(program, out_idx=-1)
+    A = torch.randn(N, dtype=getattr(torch, dtype)).cuda()
+
+    def ref_program(A):
+        ref_b = torch.empty_like(A)
+        num_blocks = (N + chunk_size - 1) // chunk_size
+        for j in range(num_blocks):
+            start = j * chunk_size
+            end = min(start + chunk_size, N)
+            chunk = A[start:end].clone()
+            if reverse:
+                chunk = torch.flip(chunk, dims=[0])
+            chunk = chunk.cumsum(dim=0)
+            if reverse:
+                chunk = torch.flip(chunk, dims=[0])
+            ref_b[start:end] = chunk
+        return ref_b
+
+    tilelang_res = jit_kernel(A)
+    ref_res = ref_program(A)
+    torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3)
+
+
+def cumsum_region_test_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"):
+    """Test cumsum with buffer region (slice) as input in 2D."""
+    import tilelang.language as T
+
+    @T.prim_func
+    def cumsum_region(
+        InputG_fragment: T.Tensor((M, N), dtype),
+        OutputG_fragment: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
+            chunk_start_M = by * block_M
+            chunk_start_N = bx * block_N
+            # Copy region to shared memory first (cumsum only supports shared memory)
+            A_shared = T.alloc_shared((block_M, block_N), dtype)
+            T.copy(
+                InputG_fragment[chunk_start_M : chunk_start_M + block_M, chunk_start_N : chunk_start_N + block_N],
+                A_shared,
+            )
+            # Test cumsum with 2D region input - in-place operation on shared memory
+            T.cumsum(src=A_shared, dim=dim, reverse=reverse)
+            # Copy result back to global memory
+            T.copy(
+                A_shared,
+                OutputG_fragment[chunk_start_M : chunk_start_M + block_M, chunk_start_N : chunk_start_N + block_N],
+            )
+
+    return cumsum_region
+
+
+def run_cumsum_region_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"):
+    """Run test for cumsum with 2D region input."""
+    program = cumsum_region_test_2d(M, N, block_M, block_N, dim, reverse, dtype)
+    jit_kernel = tl.compile(program, out_idx=-1)
+    A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda()
+
+    def ref_program(A):
+        ref_b = torch.empty_like(A)
+        num_blocks_M = (M + block_M - 1) // block_M
+        num_blocks_N = (N + block_N - 1) // block_N
+        for i in range(num_blocks_M):
+            for j in range(num_blocks_N):
+                start_M = i * block_M
+                end_M = min(start_M + block_M, M)
+                start_N = j * block_N
+                end_N = min(start_N + block_N, N)
+                chunk = A[start_M:end_M, start_N:end_N].clone()
+                if reverse:
+                    chunk = torch.flip(chunk, dims=[dim])
+                chunk = chunk.cumsum(dim=dim)
+                if reverse:
+                    chunk = torch.flip(chunk, dims=[dim])
+                ref_b[start_M:end_M, start_N:end_N] = chunk
+        return ref_b
+
+    tilelang_res = jit_kernel(A)
+    ref_res = ref_program(A)
+    torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3)
+
+
+def test_cumsum_region_1d():
+    """Test cumsum with 1D region input."""
+    # Test normal cumsum with region input
+    run_cumsum_region_1d(1024, 128)
+    # Test reverse cumsum with region input
+    run_cumsum_region_1d(1024, 128, reverse=True)
+    # Test with different chunk sizes
+    run_cumsum_region_1d(512, 64)
+    run_cumsum_region_1d(2048, 256)
+    # Tail coverage (non-divisible size)
+    run_cumsum_region_1d(1000, 128)
+
+
+def test_cumsum_region_2d():
+    """Test cumsum with 2D region input."""
+    # Test 2D cumsum along dim 0
+    run_cumsum_region_2d(1024, 1024, 128, 128, dim=0)
+    # Test 2D cumsum along dim 1
+    run_cumsum_region_2d(1024, 1024, 128, 128, dim=1)
+    # Test reverse cumsum
+    run_cumsum_region_2d(512, 512, 64, 64, dim=1, reverse=True)
+    # Tail coverage (non-divisible size)
+    run_cumsum_region_2d(1000, 1000, 128, 128, dim=1)
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
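
As a quick sanity check of the flip-cumsum-flip reference used above, a standalone PyTorch snippet (no GPU or TileLang required):

  import torch

  x = torch.arange(1.0, 7.0)  # [1, 2, 3, 4, 5, 6]
  chunk = x[:3]
  # Reverse cumsum within a chunk is a suffix sum: flip, scan, flip back
  rev = torch.flip(torch.flip(chunk, dims=[0]).cumsum(dim=0), dims=[0])
  print(rev)  # tensor([6., 5., 3.])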
@@ -13,7 +13,7 @@ from tvm import ir, tir
 def copy(
     src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
-    dst: tir.Buffer | tir.BufferLoad,
+    dst: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
     coalesced_width: int | None = None,
     disable_tma: bool = False,
     eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None,
@@ -22,7 +22,7 @@ def copy(
     Args:
         src (Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]): Source memory region
-        dst (Union[tir.Buffer, tir.BufferLoad]): Destination memory region
+        dst (Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]): Destination memory region
         coalesced_width (Optional[int], optional): Width for coalesced memory access. Defaults to None.
 
     Raises:
...
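
With this change a BufferRegion is accepted on either side of T.copy, so the write-back in the tests above can be expressed directly as (a fragment for illustration, names as in those tests):

  T.copy(A_shared, OutputG_fragment[chunk_start : chunk_start + chunk_size])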
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from tvm import tir
 from tilelang.language import copy, macro, alloc_shared, alloc_fragment
-from tilelang.utils.language import to_buffer_region
+from tilelang.utils.language import to_buffer_region, retrieve_shape, _get_buffer
 from tilelang.utils.language import is_shared, is_fragment
 from tvm.script.ir_builder import IRBuilder
@@ -242,8 +242,35 @@ def reduce_bitxor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool
 @macro
-def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -> tir.PrimExpr:
-    cumsum_smem = alloc_shared(src.shape, src.dtype, "shared.dyn")
+def cumsum_fragment(
+    src: tir.Buffer | tir.BufferRegion | tir.BufferLoad,
+    dst: tir.Buffer | tir.BufferRegion | tir.BufferLoad,
+    dim: int,
+    reverse: bool,
+) -> tir.PrimExpr:
+    """
+    Compute cumulative sum for fragment buffers by copying to shared memory first.
+
+    This macro handles cumulative sum operations on fragment buffers by first copying
+    the data to shared memory, performing the cumsum operation, and then copying back.
+
+    Args:
+        src: Source buffer (Buffer, BufferRegion, or BufferLoad) containing input data.
+        dst: Destination buffer (Buffer, BufferRegion, or BufferLoad) for output data.
+        dim: Dimension along which to compute cumulative sum.
+        reverse: If True, compute cumulative sum in reverse order.
+
+    Returns:
+        tir.PrimExpr: A handle to the cumulative sum operation.
+    """
+    src_shape = retrieve_shape(src)
+    src_buffer = _get_buffer(src)
+    # Get dtype from the buffer
+    if isinstance(src, tir.Buffer):
+        dtype = src.dtype
+    else:
+        dtype = src_buffer.dtype
+    cumsum_smem = alloc_shared(src_shape, dtype, "shared.dyn")
     copy(src, cumsum_smem)
     tir.call_intrin(
         "handle",
@@ -256,12 +283,19 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) ->
     copy(cumsum_smem, dst)
 
 
-def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse: bool = False):
+def cumsum(
+    src: tir.Buffer | tir.BufferRegion | tir.BufferLoad,
+    dst: tir.Buffer | tir.BufferRegion | tir.BufferLoad | None = None,
+    dim: int = 0,
+    reverse: bool = False,
+):
     """
     Compute the cumulative sum of `src` along `dim`, writing results to `dst`.
 
     Negative `dim` indices are normalized (Python-style). If `dst` is None, the operation is performed in-place into `src`. Raises ValueError when `dim` is out of bounds for `src.shape`. When `src.scope() == "local.fragment"`, this delegates to `cumsum_fragment`; otherwise it emits the `tl.cumsum` intrinsic.
+    Supports Buffer, BufferRegion, and BufferLoad inputs, allowing operations on buffer slices/regions.
 
     Examples:
         A 1D inclusive scan that writes the result into a separate shared-memory buffer:
@@ -285,19 +319,40 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse: bool = False):
         ...         T.cumsum(src=tile, dim=1, reverse=True)
         ...         T.copy(tile, B)
 
+        Operating on a buffer region (slice):
+
+        >>> import tilelang.language as T
+        >>> @T.prim_func
+        ... def kernel_region(InputG_fragment: T.Tensor((128,), "float32"), chunk_size: T.int32):
+        ...     with T.Kernel(1, threads=128):
+        ...         i = T.int32(0)
+        ...         T.cumsum(InputG_fragment[i * chunk_size:(i + 1) * chunk_size], dim=0)
+
     Returns:
         tir.Call: A handle to the emitted cumulative-sum operation.
     """
-    shape = src.shape
-    if dim >= len(shape) or dim <= -len(shape):
+    # Get shape from src (supports Buffer, BufferRegion, BufferLoad)
+    shape = retrieve_shape(src)
+    if dim >= len(shape) or dim < -len(shape):
         raise ValueError(f"Dimension {dim} is out of bounds for buffer with shape {shape}")
     if dim < 0:
         dim = len(shape) + dim
     if dst is None:
         dst = src
-    if src.scope() == "local.fragment":
+    else:
+        # Validate that dst shape matches src shape
+        dst_shape = retrieve_shape(dst)
+        if len(dst_shape) != len(shape):
+            raise ValueError(f"cumsum dst shape {dst_shape} must match src shape {shape} (rank mismatch)")
+        # Check each dimension matches
+        for i in range(len(shape)):
+            if not tir.analysis.expr_deep_equal(dst_shape[i], shape[i]):
+                raise ValueError(f"cumsum dst shape {dst_shape} must match src shape {shape} (dim {i} mismatch)")
+    # Check if src is a fragment buffer
+    if is_fragment(src):
         return cumsum_fragment(src, dst, dim, reverse)
     return tir.call_intrin(
         "handle",
...
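
For intuition, a rough sketch of what a retrieve_shape-style helper does (an assumption about its behavior, not the actual tilelang implementation; the BufferLoad case is simplified to a scalar):

  from tvm import tir

  def retrieve_shape_sketch(obj):
      """Assumed behavior: normalize Buffer/BufferRegion/BufferLoad to a shape."""
      if isinstance(obj, tir.Buffer):
          return list(obj.shape)                 # full buffer shape
      if isinstance(obj, tir.BufferRegion):
          return [r.extent for r in obj.region]  # per-dimension slice extents
      if isinstance(obj, tir.BufferLoad):
          return []                              # simplified: a scalar element load
      raise TypeError(f"unsupported input type: {type(obj)}")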