"docs/tutorial_2_RemoteMachineMode.md" did not exist on "8d866b5b9dbf7d9b2ed4ea85f9133ccf1fec62bc"
Unverified commit 4ca6c131 authored by Yuqi Dong, committed by GitHub

[CI]: Reduce test shapes to avoid OOM errors during CI. (#1060)



* [CI]: Reduce test shapes to avoid OOM errors during CI.

* rabbit

* Increase the number of pytest worker processes from 2 to 4.

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 759c2e33
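Every Python hunk below follows the same pattern: shapes that were hard-coded inside an example's `main()` become keyword arguments whose defaults preserve the old standalone behavior, so the CI tests can call `main(...)` with shapes small enough to fit in GPU memory. A minimal sketch of the pattern, with illustrative names rather than code taken verbatim from the repository:

```python
# Before: shapes fixed inside main(); CI could not shrink them
# without editing the example itself.
#
#     def main():
#         M, N, K = 16384, 16384, 16384
#         ...

# After: same defaults for standalone runs, but callers may override.
def main(M=16384, N=16384, K=16384):
    # ... build and benchmark the kernel for the requested shapes ...
    print(f"running with M={M}, N={N}, K={K}")

def test_example():
    # The CI test passes shapes small enough to avoid OOM.
    main(M=1024, N=1024, K=1024)

if __name__ == "__main__":
    main()  # a standalone invocation keeps the original large shapes
```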
@@ -332,7 +332,7 @@ jobs:
           uv run --no-project -m --
           pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
         )
-        "${PYTEST[@]}" --maxfail=3 --numprocesses=2 \
+        "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
           ../examples
       # NVIDIA CUDA tests
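`--numprocesses` comes from the pytest-xdist plugin; the change above runs the example suite in four worker processes instead of two. A sketch of reproducing the CI invocation locally, assuming pytest and pytest-xdist are installed (the script name and path are illustrative):

```python
# run_examples.py -- hypothetical local reproduction of the CI pytest call.
import sys

import pytest  # pytest-xdist must also be installed for --numprocesses

if __name__ == "__main__":
    # Four parallel workers, stop after three failures, as in the CI job.
    sys.exit(pytest.main(["--numprocesses=4", "--maxfail=3", "../examples"]))
```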
@@ -24,11 +24,27 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
 def test_example_triton_sparse_gqa_decode_varlen_indice():
-    example_triton_sparse_gqa_decode_varlen_indice.main()
+    example_triton_sparse_gqa_decode_varlen_indice.main(
+        batch=16,
+        heads=16,
+        heads_kv=8,
+        max_cache_seqlen=4096,
+        dim=128,
+        dim_v=128,
+        sparse_ratio=0.8,
+        block_size=32)


 def test_example_triton_sparse_gqa_decode_varlen_mask():
-    example_triton_sparse_gqa_decode_varlen_mask.main()
+    example_triton_sparse_gqa_decode_varlen_mask.main(
+        batch=16,
+        heads=16,
+        heads_kv=8,
+        max_cache_seqlen=4096,
+        dim=128,
+        dim_v=128,
+        sparse_ratio=0.8,
+        block_size=32)


 if __name__ == "__main__":
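Pinning the decode shapes also bounds the KV cache the test allocates. Assuming an fp16 cache laid out roughly as [batch, max_cache_seqlen, heads_kv, dim + dim_v] (an assumption about the example's internals, not something stated in the diff), the shapes above imply about 256 MiB:

```python
# Back-of-envelope KV-cache footprint for the pinned test shapes.
# Layout assumed: [batch, max_cache_seqlen, heads_kv, dim + dim_v], fp16 = 2 bytes.
batch, max_cache_seqlen, heads_kv = 16, 4096, 8
dim, dim_v = 128, 128

kv_bytes = batch * max_cache_seqlen * heads_kv * (dim + dim_v) * 2
print(f"{kv_bytes / 2**20:.0f} MiB")  # -> 256 MiB
```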
@@ -161,8 +161,7 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
     return x_fp8


-def main():
-    M, N, BG, blk_m = 8192, 8192, 2, 8
+def main(M=8192, N=8192, BG=2, blk_m=8):
     if dtype == "float":
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
     elif dtype == "float16":
@@ -79,8 +79,7 @@ def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return x_fp8, (x_amax / 448.0).view(m, -1)


-def main():
-    M, N, blk_m = 8192, 8192, 8
+def main(M=8192, N=8192, blk_m=8):
     kernel = per_token_cast_to_fp8(M, N, blk_m)
     print(kernel.get_kernel_source())
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
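For context on the `448.0` in `ref_program`: 448 is the largest finite value of `torch.float8_e4m3fn`, so per-token scaling maps each row's absolute maximum onto that bound and returns the inverse scale. A minimal sketch of that reference logic, assuming a plain 2-D input (the actual example also groups rows by `blk_m`, and the clamp threshold here is chosen for illustration):

```python
import torch

def per_token_cast_ref(x: torch.Tensor):
    # Scale each row so its amax lands on 448, the e4m3 max finite value.
    m, n = x.shape
    # Clamp guards against all-zero rows (threshold chosen for illustration).
    x_amax = x.abs().float().amax(dim=1, keepdim=True).clamp(min=1e-4)
    x_fp8 = (x * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    # Return the quantized tensor and the per-row inverse scale.
    return x_fp8, (x_amax / 448.0).view(m, -1)
```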
@@ -4,11 +4,11 @@ import example_per_token_cast_to_fp8
 def test_example_group_per_split_token_cast_to_fp8():
-    example_group_per_split_token_cast_to_fp8.main()
+    example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)


 def test_example_per_token_cast_to_fp8():
-    example_per_token_cast_to_fp8.main()
+    example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)


 if __name__ == "__main__":
@@ -13,7 +13,7 @@ def test_example_topk_selector():
 def test_example_fp8_lighting_indexer():
-    test_fp8_lighting_indexer()
+    test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)


 @tilelang.testing.requires_cuda
@@ -96,8 +96,7 @@ def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtyp
     print(f"Latency: {latency} ms")


-def main():
-    M, N, K = 16384, 16384, 16384
+def main(M=16384, N=16384, K=16384):
     block_M, block_N, block_K = 128, 128, 32
     trans_A, trans_B = False, False
     in_dtype, out_dtype = "float16", "float16"
@@ -3,7 +3,7 @@ import example_dynamic
 def test_example_dynamic():
-    example_dynamic.main()
+    example_dynamic.main(M=1024, N=1024, K=1024)


 if __name__ == "__main__":
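The shrink from the 16384 defaults down to 1024 in the test is where most of the memory saving comes from, since matrix footprints grow quadratically with the side length. A quick check in fp16 (2 bytes per element), counting only the A, B, and C buffers:

```python
def matmul_buffer_mib(m, n, k, bytes_per_elem=2):
    # A is m*k, B is k*n, C is m*n; fp16 = 2 bytes per element.
    return (m * k + k * n + m * n) * bytes_per_elem / 2**20

print(matmul_buffer_mib(16384, 16384, 16384))  # 1536.0 MiB (1.5 GiB)
print(matmul_buffer_mib(1024, 1024, 1024))     # 6.0 MiB
```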
@@ -44,12 +44,14 @@ def test_example_mha_bwd_wgmma_pipelined():
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_gqa_fwd_bshd_wgmma_pipelined():
-    example_gqa_fwd_bshd_wgmma_pipelined.main()
+    example_gqa_fwd_bshd_wgmma_pipelined.main(
+        batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)


 @tilelang.testing.requires_cuda
 def test_example_gqa_fwd_bshd():
-    example_gqa_fwd_bshd.main()
+    example_gqa_fwd_bshd.main(
+        batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)


 @tilelang.testing.requires_cuda
@@ -52,8 +52,8 @@ def run_gemm_pipeline_test(N, block_M=128, block_N=128, block_K=32):
 def test_pipeline_large_matrix():
-    """Test pipeline stages with large matrix multiplication (8192x8192)"""
-    run_gemm_pipeline_test(8192)
+    """Test pipeline stages with large matrix multiplication (4096x4096)"""
+    run_gemm_pipeline_test(4096)


 def test_pipeline_small_matrix():