Commit 12714155 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Feature] Add reduce_max corresponding tests (#236)

* [Feature] Add reduce_max functionality and corresponding tests

* Introduced a new test file for the reduce_max operation in the tilelang language module.
* Implemented the reduce_max functionality using T.prim_func, including local memory allocation and result copying.
* Added tests for various input sizes and data types to ensure correctness of the reduce_max implementation.
* Enhanced profiling assertions to validate the output against reference implementations.

* Fix whitespace issues in reduce_max test file for improved readability
parent 622fa042
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
def reduce_max_test(M, N, dtype="float16"):
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer((M, N), dtype),
B: T.Buffer((M,), dtype),
):
with T.Kernel(1) as _:
A_local = T.alloc_fragment((M, N), dtype)
B_local = T.alloc_fragment((M,), dtype)
# Copy input to local
T.copy(A, A_local)
# Perform reduce_max operation
T.reduce_max(A_local, B_local, dim=1)
# Copy result back
T.copy(B_local, B)
return main
def run_reduce_max(M, N, dtype="float16"):
program = reduce_max_test(M, N, dtype)
jit_kernel = tl.compile(program, out_idx=-1)
profiler = jit_kernel.get_profiler()
def ref_program(A):
return A.max(dim=1).values
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reduce_max():
# Test different sizes
run_reduce_max(256, 256)
run_reduce_max(512, 128)
run_reduce_max(128, 512)
# Test different dtypes
run_reduce_max(256, 256, "float32")
run_reduce_max(256, 256, "float16")
if __name__ == "__main__":
tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment