Unverified commit c61971e8, authored by Lei Wang and committed by GitHub
Browse files

[Enhancement] Add buffer load copy functions and improve copy logic in tilelang (#946)

- Introduced new functions for buffer load copy with stride and parallel execution.
- Enhanced the copy logic in `copy.py` to simplify nested if statements for BufferLoad nodes.
- Added corresponding test cases for the new buffer load functionalities.
parent 91d5ef54
...@@ -86,5 +86,74 @@ def test_tilelang_copy_with_stride(): ...@@ -86,5 +86,74 @@ def test_tilelang_copy_with_stride():
run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128) run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128)
def tilelang_copy_bufferload(num_tokens, dtype="float16"):
    """Build a kernel that uses ``T.copy`` between two single-element accesses.

    For each ``pid`` produced by ``T.Kernel(num_tokens, threads=32)``, the
    kernel copies ``indices[pid]`` into a one-element buffer obtained from
    ``T.alloc_local`` and then adds 1 to ``x`` at that position.

    Args:
        num_tokens: Length of ``indices`` and ``x``; also the extent passed
            to ``T.Kernel``.
        dtype: Element type of ``x``.

    Returns:
        The ``T.prim_func`` describing the kernel.
    """
    vec_shape = (num_tokens,)

    @T.prim_func
    def main(
            indices: T.Tensor(vec_shape, "int32"),
            x: T.Tensor(vec_shape, dtype),
    ):
        with T.Kernel(num_tokens, threads=32) as pid:
            idx = T.alloc_local([1], "int32")
            # Source and destination are both indexed elements, not whole
            # buffers or regions -- this is the form of T.copy under test.
            T.copy(indices[pid], idx[0])
            x[idx[0]] = x[idx[0]] + 1

    return main
def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"):
    """Compile the kernel from ``tilelang_copy_bufferload`` without running it.

    Args:
        num_tokens: Forwarded to ``tilelang_copy_bufferload``.
        dtype: Forwarded to ``tilelang_copy_bufferload``.
    """
    program = tilelang_copy_bufferload(num_tokens, dtype)
    lowering_options = {
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    }
    # test compilation only: the compiled kernel is intentionally discarded
    tilelang.compile(program, out_idx=[1], pass_configs=lowering_options)
def test_tilelang_copy_bufferload():
    """Compile-only check of ``T.copy`` between two single-element accesses."""
    run_tilelang_copy_bufferload(num_tokens=128)
def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float16"):
    """Build a kernel that copies ``A`` into ``B`` element by element.

    The kernel is launched over ``ceildiv(N, block_N)`` by
    ``ceildiv(M, block_M)`` blocks. Inside each block a
    ``T.Parallel(block_M, block_N)`` loop issues one ``T.copy`` per element,
    with both arguments written as indexed element accesses.

    Args:
        M: Number of rows of ``A`` and ``B``.
        N: Number of columns of ``A`` and ``B``.
        block_M: Rows covered by one block (offset by ``by``).
        block_N: Columns covered by one block (offset by ``bx``).
        dtype: Element type of both tensors.

    Returns:
        The ``T.prim_func`` describing the kernel.
    """
    mat_shape = (M, N)

    @T.prim_func
    def main(
            A: T.Tensor(mat_shape, dtype),
            B: T.Tensor(mat_shape, dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                # Element-to-element copy at the same (row, column) in A and B.
                T.copy(A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j])

    return main
def run_tilelang_copy_buffer_load_with_parallel(M=1024,
                                                N=1024,
                                                block_M=128,
                                                block_N=128,
                                                dtype="float16"):
    """Compile and run the element-wise parallel copy kernel on CUDA.

    A random ``(M, N)`` tensor is passed to the compiled kernel and the
    returned tensor is compared against the input.

    Args:
        M: Number of rows of the test matrix.
        N: Number of columns of the test matrix.
        block_M: Row tile size forwarded to the kernel builder.
        block_N: Column tile size forwarded to the kernel builder.
        dtype: Name of a ``torch`` dtype attribute, e.g. ``"float16"``.

    Raises:
        AssertionError: If the kernel output is not close to its input.
    """
    program = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype)
    lowering_options = {
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    }
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs=lowering_options,
    )

    torch_dtype = getattr(torch, dtype)
    source = torch.randn(M, N, device="cuda", dtype=torch_dtype)
    # Only the input is supplied; the kernel call hands back the result tensor.
    copied = kernel(source)
    torch.testing.assert_close(copied, source, rtol=1e-2, atol=1e-2)
def test_tilelang_copy_buffer_load_with_parallel():
    """Run the element-wise parallel copy on a 1024x1024 matrix with 128x128 tiles."""
    run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128)
# Script entry point: hand control to tilelang's test runner when this file
# is executed directly rather than imported.
if __name__ == "__main__":
    tilelang.testing.main()
...@@ -45,6 +45,14 @@ def copy(src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion], ...@@ -45,6 +45,14 @@ def copy(src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion],
src_extent = get_extent(src) src_extent = get_extent(src)
dst_extent = get_extent(dst) dst_extent = get_extent(dst)
# Combine the nested if statements into a single if statement as suggested by SIM102
if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and
isinstance(dst, tir.BufferLoad)):
# check if the case is like this:
# copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes
# In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i]
return tir.BufferStore(dst.buffer, src, dst.indices)
assert src_extent or dst_extent, "Can't deduce copy extents from args" assert src_extent or dst_extent, "Can't deduce copy extents from args"
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment