Commit 2fff0eec authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

Delete testing/python/language/test_tilelang_language_reduce_sum.py (#424)

parent 88747fcd
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
tilelang.disable_cache()
def reduce_sum_test(M, N, dtype="float16"):
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
):
with T.Kernel(1) as _:
A_local = T.alloc_fragment((M, N), dtype)
B_local = T.alloc_fragment((M,), dtype)
# Copy input to local
T.copy(A, A_local)
# Perform reduce_sum operation
T.reduce_sum(A_local, B_local, dim=0)
# Copy result back
T.copy(B_local, B)
return main
def run_reduce_sum(M, N, dtype="float16"):
program = reduce_sum_test(M, N, dtype)
jit_kernel = tl.compile(program, out_idx=-1)
print(jit_kernel.get_kernel_source())
profiler = jit_kernel.get_profiler()
def ref_program(A):
return A.sum(dim=0)
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reduce_sum():
# Test different sizes
run_reduce_sum(256, 256)
run_reduce_sum(512, 128)
run_reduce_sum(128, 512)
# Test different dtypes
run_reduce_sum(256, 256, "float32")
run_reduce_sum(256, 256, "float16")
if __name__ == "__main__":
tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment