"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "983d27edff4813600ed648687113da46d0eeeec7"
Commit 73ae8087 authored by Yu Cheng's avatar Yu Cheng Committed by LeiWang1999
Browse files

[Refactor] Update main function structure in example scripts and add tests (#475)

* [Refactor] Update example_mla_decode.py and add tests for block_sparse_attn_tilelang

* Refactor example_mla_decode.py to define a main function for better structure and clarity.
* Introduce test_example_mla_decode.py to validate the functionality of example_mla_decode.
* Refactor block_sparse_attn_tilelang.py to define a main function and add test_block_sparse_attn_tilelang.py for testing.
* Ensure all new test files are integrated with tilelang testing framework.

* [Test] Enhance test_example_mla_decode with argument mocking

* Update test_example_mla_decode.py to mock sys.argv for better test isolation.
* Ensure the main function of example_mla_decode is called with the correct arguments during testing.
parent dca2fb48
...@@ -270,7 +270,7 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): ...@@ -270,7 +270,7 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
return out return out
if __name__ == "__main__": def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size') parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number') parser.add_argument('--heads', type=int, default=128, help='q heads number')
...@@ -294,3 +294,7 @@ if __name__ == "__main__": ...@@ -294,3 +294,7 @@ if __name__ == "__main__":
latency = profiler.do_bench(warmup=500) latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms") print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops") print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
if __name__ == "__main__":
main()
\ No newline at end of file
import tilelang.testing
import example_mla_decode
from unittest import mock
import sys
@tilelang.testing.requires_cuda
def test_example_mla_decode():
with mock.patch.object(sys, 'argv', ["example_mla_decode.py"]):
example_mla_decode.main()
if __name__ == "__main__":
tilelang.testing.main()
...@@ -259,6 +259,10 @@ def test_topk_sparse_attention_qlen_lt_klen(): ...@@ -259,6 +259,10 @@ def test_topk_sparse_attention_qlen_lt_klen():
print("Pass topk sparse attention test with qlen < klen") print("Pass topk sparse attention test with qlen < klen")
if __name__ == "__main__": def main():
test_topk_sparse_attention() test_topk_sparse_attention()
test_topk_sparse_attention_qlen_lt_klen() test_topk_sparse_attention_qlen_lt_klen()
if __name__ == "__main__":
main()
\ No newline at end of file
import tilelang.testing
import block_sparse_attn_tilelang
@tilelang.testing.requires_cuda
def test_block_sparse_attn_tilelang():
block_sparse_attn_tilelang.main()
if __name__ == "__main__":
tilelang.testing.main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment