Unverified Commit f810f976 authored by LJC00118's avatar LJC00118 Committed by GitHub
Browse files

Add unit tests for T.assume (#1341)



* Add test for T.assume

* Add unit test for T.assume

* Add unit test for T.assume

* Add unit tests for T.assume

* Remove debug print for kernel source

Remove print statement for kernel source in tests.

* Update test_tilelang_language_assume.py

---------
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent e2b10c58
import tilelang
import tilelang.language as T
import tilelang.testing
def test_assume_remove_boundary_check():
@tilelang.jit
def kernel_with_assume():
N = T.dynamic('N')
@T.prim_func
def main(A: T.Tensor((N,), "float32"), l: T.int32, r: T.int32):
with T.Kernel(1, threads=32) as _:
for i in T.serial(r - l + 1):
T.assume(l + i >= 0 and l + i < N)
A[l + i] = 0
return main
jit_kernel = kernel_with_assume()
source = jit_kernel.get_kernel_source()
assert ("if (" not in source)
def test_assume_enable_vectorization():
@tilelang.jit
def kernel_vectorize(M):
N = T.dynamic('N')
vectorize_size = 4
@T.prim_func
def main(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
):
with T.Kernel(1, threads=32) as _:
tid = T.get_thread_binding()
base_idx = tid * 4
T.assume(N % vectorize_size == 0)
for i in T.vectorized(vectorize_size):
T.assume(base_idx + i < N)
B[tid, base_idx + i] = A[tid, base_idx + i]
return main
jit_kernel = kernel_vectorize(128)
source = jit_kernel.get_kernel_source()
assert ("float4" in source) and ("if (" not in source)
def test_assume_complex_indexing():
@tilelang.jit
def kernel_complex():
M = T.dynamic('M')
N = T.dynamic('N')
@T.prim_func
def main(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
):
with T.Kernel(1, threads=32) as _:
tid = T.get_thread_binding()
for j in T.serial(N):
i_src = T.min(j + 233, tid + 2)
j_src = j * T.ceildiv(j, i_src) * j - 1
T.assume(i_src >= 0 and i_src < M)
T.assume(j_src >= 0 and j_src < N)
B[tid, j] = A[i_src, j_src]
return main
jit_kernel = kernel_complex()
source = jit_kernel.get_kernel_source()
assert ("if (" not in source)
if __name__ == '__main__':
tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment