Commit a636debb authored by Zhengju Tang's avatar Zhengju Tang Committed by LeiWang1999
Browse files

[Pytest Fix] Wrap tests in dynamic benchmark (#387)

[Dynamic Symbolic] Add pass_config to customize vectorization and tail split
[Pytest Fix] Wrap tests in dynamic benchmark
parent 280e6627
...@@ -3,6 +3,7 @@ import torch.backends ...@@ -3,6 +3,7 @@ import torch.backends
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang.testing import tilelang.testing
import tilelang.language as T import tilelang.language as T
import pytest
tilelang.testing.set_random_seed(0) tilelang.testing.set_random_seed(0)
tilelang.disable_cache() tilelang.disable_cache()
...@@ -448,11 +449,13 @@ def assert_tl_matmul_block_dynamic_mnk( ...@@ -448,11 +449,13 @@ def assert_tl_matmul_block_dynamic_mnk(
print(f"Dynamic MNK Latency with pass_configs: {pass_configs} is {latency} ms") print(f"Dynamic MNK Latency with pass_configs: {pass_configs} is {latency} ms")
@pytest.mark.skip("Skip static test")
def test_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K): def test_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16",
"float16", "float32") "float16", "float32")
@pytest.mark.skip("Skip dynamic m test")
def test_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): def test_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_dynamic_m( assert_tl_matmul_block_dynamic_m(
M, M,
...@@ -485,6 +488,7 @@ def test_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): ...@@ -485,6 +488,7 @@ def test_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
pass_configs={"tl.disable_dynamic_tail_split": False}) pass_configs={"tl.disable_dynamic_tail_split": False})
@pytest.mark.skip("Skip dynamic mn test")
def test_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): def test_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_dynamic_mn( assert_tl_matmul_block_dynamic_mn(
M, M,
...@@ -517,6 +521,7 @@ def test_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): ...@@ -517,6 +521,7 @@ def test_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
pass_configs={"tl.disable_dynamic_tail_split": False}) pass_configs={"tl.disable_dynamic_tail_split": False})
@pytest.mark.skip("Skip dynamic mnk test")
def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_dynamic_mnk( assert_tl_matmul_block_dynamic_mnk(
M, M,
...@@ -549,7 +554,7 @@ def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): ...@@ -549,7 +554,7 @@ def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
pass_configs={"tl.disable_dynamic_tail_split": False}) pass_configs={"tl.disable_dynamic_tail_split": False})
def assert_all(): def test_all():
test_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32) test_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32)
test_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32) test_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32)
test_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32) test_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32)
...@@ -557,4 +562,4 @@ def assert_all(): ...@@ -557,4 +562,4 @@ def assert_all():
if __name__ == "__main__": if __name__ == "__main__":
assert_all() tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment