Unverified Commit 53c73028 authored by RUTHLESS-BOT's avatar RUTHLESS-BOT Committed by GitHub
Browse files

[Misc] parametrize 'dtype' in test_flash_mla (#22641)


Signed-off-by: default avatarRUTHLESS-BOT <wujiafeng@cmbchina.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 6534d2fc
...@@ -35,11 +35,10 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ ...@@ -35,11 +35,10 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True]) @pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@torch.inference_mode() @torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen): varlen, dtype):
# TODO: parametrize using pytest
dtype = torch.bfloat16
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.set_default_device(device) torch.set_default_device(device)
...@@ -48,7 +47,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, ...@@ -48,7 +47,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
random.seed(0) random.seed(0)
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}") f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
if varlen: if varlen:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment