"examples/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "ea5483016d0bb4ea27c7b695ef3a9d66e44c493c"
Unverified commit 13bdcd60, authored by Yuqi Dong, committed by GitHub

[Refactor]: Change the params in pytest to avoid OOM errors during CI (#1170)

* [Refactor]: Change the params in pytest to avoid OOM errors during CI

* format

* fix

* Update test_example_cast.py

* Update parameters in test_example_cast

* Update test_example_flash_attention.py

* update

* format

* fix

* fix

* format
parent 5c62d00a
@@ -25,10 +25,10 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
 def test_example_triton_sparse_gqa_decode_varlen_indice():
     example_triton_sparse_gqa_decode_varlen_indice.main(
-        batch=16,
-        heads=16,
-        heads_kv=8,
-        max_cache_seqlen=4096,
+        batch=8,
+        heads=8,
+        heads_kv=4,
+        max_cache_seqlen=2048,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
@@ -40,7 +40,7 @@ def test_example_triton_sparse_gqa_decode_varlen_mask():
         batch=16,
         heads=16,
         heads_kv=8,
-        max_cache_seqlen=4096,
+        max_cache_seqlen=1024,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
...
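For a rough sense of why shrinking these arguments helps CI memory, the back-of-the-envelope sketch below estimates the fp16 KV cache implied by the old and new shapes. This is an assumption about what dominates the footprint, not something the diff states; the example also allocates other buffers.

# Hypothetical estimate, not part of the diff: assumes an fp16 KV cache of shape
# (batch, max_cache_seqlen, heads_kv, dim + dim_v) dominates the test's GPU memory.
def kv_cache_bytes(batch, max_cache_seqlen, heads_kv, dim, dim_v, bytes_per_elem=2):
    return batch * max_cache_seqlen * heads_kv * (dim + dim_v) * bytes_per_elem

old = kv_cache_bytes(batch=16, max_cache_seqlen=4096, heads_kv=8, dim=128, dim_v=128)
new = kv_cache_bytes(batch=8, max_cache_seqlen=2048, heads_kv=4, dim=128, dim_v=128)
print(old / 2**20, new / 2**20)  # ~256 MiB vs ~32 MiB, roughly an 8x reduction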
@@ -161,7 +161,9 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
     return x_fp8

-def main(M=8192, N=8192, BG=2, blk_m=8):
+def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
+    if batch_sizes is None:
+        batch_sizes = [2048, 6144]
     if dtype == "float":
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
     elif dtype == "float16":
@@ -170,7 +172,7 @@ def main(M=8192, N=8192, BG=2, blk_m=8):
         x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
-    batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32)
+    batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32)
     M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
     print("batch_sizes:", batch_sizes)
...
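The `batch_sizes=None` default above is the standard Python idiom for avoiding a mutable list as a default argument while still letting callers (the CI test in particular) pass smaller group sizes. A minimal, self-contained sketch of the pattern, using a stand-in function rather than the example's real main():

# Sketch of the default-argument pattern; demo() is a hypothetical stand-in.
def demo(batch_sizes=None):
    # A literal list default would be created once and shared across calls;
    # None plus an in-body assignment keeps each call independent and overridable.
    if batch_sizes is None:
        batch_sizes = [2048, 6144]
    return batch_sizes

assert demo() == [2048, 6144]
assert demo([128, 896]) == [128, 896]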
@@ -4,11 +4,12 @@ import example_per_token_cast_to_fp8

 def test_example_group_per_split_token_cast_to_fp8():
-    example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)
+    example_group_per_split_token_cast_to_fp8.main(
+        M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])

 def test_example_per_token_cast_to_fp8():
-    example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)
+    example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8)

 if __name__ == "__main__":
...
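In both the old and the new call the per-group sizes appear to partition M (2048 + 6144 = 8192, and 128 + 896 = 1024), so anyone tuning these numbers further presumably wants to keep that invariant. A one-line sanity check under that assumption:

M, BG, batch_sizes = 1024, 2, [128, 896]
assert len(batch_sizes) == BG and sum(batch_sizes) == M  # assumed contract, inferred from the call sites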
@@ -13,7 +13,7 @@ def test_example_topk_selector():

 def test_example_fp8_lighting_indexer():
-    test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)
+    test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)

 @tilelang.testing.requires_cuda
@@ -29,14 +29,14 @@ def test_example_sparse_mla_fwd():
 def test_example_sparse_mla_fwd_pipelined():
     # small shapes for testing
     test_sparse_mla_fwd_pipelined(
-        S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)

 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_bwd():
     test_sparse_mla_bwd(
-        S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)

 if __name__ == "__main__":
...
@@ -33,18 +33,30 @@ def test_example_gqa_bwd_wgmma_pipelined():
 @tilelang.testing.requires_cuda
 def test_example_mha_bwd():
-    example_mha_bwd.main(BATCH=1)
+    example_mha_bwd.main(
+        BATCH=1,
+        H=16,
+        N_CTX=512,
+        D_HEAD=64,
+        causal=False,
+    )

 @tilelang.testing.requires_cuda
 def test_example_mha_bwd_bhsd():
-    example_mha_bwd_bhsd.main(BATCH=1)
+    example_mha_bwd_bhsd.main(
+        BATCH=1,
+        H=16,
+        N_CTX=512,
+        D_HEAD=64,
+        causal=False,
+    )

 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_mha_bwd_wgmma_pipelined():
-    example_mha_bwd_wgmma_pipelined.main(BATCH=1)
+    example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)

 @tilelang.testing.requires_cuda
@@ -84,7 +96,7 @@ def test_example_mha_fwd_bshd():
 @tilelang.testing.requires_cuda
 def test_example_mha_fwd_varlen():
-    example_mha_fwd_varlen.main()
+    example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)

 if __name__ == "__main__":
...
@@ -302,9 +302,7 @@ def flash_split_ref(Q, K, V, causal):
         3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)

-def main():
-    BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
-    causal = False
+def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
     flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
     total_flops = 2 * flops_per_matmul
     if causal:
...
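With the hard-coded shapes moved into keyword defaults, the compute cost scales directly with whatever a caller passes. Using the FLOPs formula from this hunk with causal=False, the CI shape used by the test below (KV_CTX=2048) does roughly a quarter of the work of the old hard-coded KV_CTX=8192; a worked sketch of the arithmetic:

# Worked numbers for the formula above (causal=False, so total_flops = 2 * flops_per_matmul).
BATCH, H, Q_CTX, D_HEAD = 1, 32, 128, 128
old = 2 * (2.0 * BATCH * H * Q_CTX * 8192 * D_HEAD)  # ~17.2 GFLOP
new = 2 * (2.0 * BATCH * H * Q_CTX * 2048 * D_HEAD)  # ~4.3 GFLOP
print(old / 1e9, new / 1e9)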
@@ -12,7 +12,7 @@ def test_example_example_gqa_decode():

 def test_example_example_mha_inference():
-    example_mha_inference.main()
+    example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)

 if __name__ == "__main__":
...
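To reproduce one of these CI cases locally without going through pytest, calling the example's main directly with the same keyword arguments is enough. A usage sketch; it assumes you run from the example's own directory so the module imports:

# Usage sketch with the same arguments as the test above; the import path is an assumption.
import example_mha_inference

example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)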