Unverified commit 6afc28a9, authored by Wentao Ye, committed by GitHub
Browse files

[Test] Batch Invariant: Unit test using parameterized backend (#27478)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 141e6a05
...@@ -17,16 +17,10 @@ skip_unsupported = pytest.mark.skipif( ...@@ -17,16 +17,10 @@ skip_unsupported = pytest.mark.skipif(
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def enable_batch_invariant_mode(): def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests.""" """Automatically enable batch invariant kernel overrides for all tests."""
old_value = os.environ.get("VLLM_BATCH_INVARIANT") monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
os.environ["VLLM_BATCH_INVARIANT"] = "1"
yield yield
# Restore original value after test
if old_value is None:
os.environ.pop("VLLM_BATCH_INVARIANT", None)
else:
os.environ["VLLM_BATCH_INVARIANT"] = old_value
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
...@@ -76,7 +70,13 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: ...@@ -76,7 +70,13 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
@skip_unsupported @skip_unsupported
@pytest.mark.timeout(1000) @pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): @pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
backend, monkeypatch: pytest.MonkeyPatch
):
""" """
Ensures that the same request (the 'needle' prompt) yields identical output Ensures that the same request (the 'needle' prompt) yields identical output
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64), whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
...@@ -101,6 +101,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): ...@@ -101,6 +101,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# Allow overrides from environment (useful for CI tuning) # Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism # "facebook/opt-125m" is too small, doesn't reliably test determinism
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
...@@ -220,11 +221,15 @@ def _extract_step_logprobs(request_output): ...@@ -220,11 +221,15 @@ def _extract_step_logprobs(request_output):
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"]) @pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
)
@pytest.mark.forked @pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) backend, monkeypatch: pytest.MonkeyPatch
os.environ["VLLM_ATTENTION_BACKEND"] = backend ):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
...@@ -435,11 +440,16 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): ...@@ -435,11 +440,16 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
@skip_unsupported @skip_unsupported
def test_simple_generation(): @pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
)
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
""" """
Simple test that runs the model with a basic prompt and prints the output. Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging. Useful for quick smoke testing and debugging.
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
llm = LLM( llm = LLM(
...@@ -481,9 +491,14 @@ def test_simple_generation(): ...@@ -481,9 +491,14 @@ def test_simple_generation():
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"]) @pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
)
@pytest.mark.forked @pytest.mark.forked
def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): def test_logprobs_without_batch_invariance_should_fail(
backend, monkeypatch: pytest.MonkeyPatch
):
""" """
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN. This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
It DISABLES batch invariance mode and expects to see non-deterministic behavior It DISABLES batch invariance mode and expects to see non-deterministic behavior
...@@ -493,14 +508,11 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): ...@@ -493,14 +508,11 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
The test will PASS if we detect differences (proving batch invariance matters). The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed). The test will FAIL if everything matches (suggesting batch invariance isn't needed).
""" """
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
os.environ["VLLM_ATTENTION_BACKEND"] = backend
# CRITICAL: Disable batch invariance for this test # CRITICAL: Disable batch invariance for this test
old_value = os.environ.get("VLLM_BATCH_INVARIANT") monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
os.environ["VLLM_BATCH_INVARIANT"] = "0"
try:
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
...@@ -550,9 +562,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): ...@@ -550,9 +562,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
bs1_logprobs_per_prompt = [] bs1_logprobs_per_prompt = []
bs1_tokens_per_prompt = [] bs1_tokens_per_prompt = []
for idx, p in enumerate(prompts): for idx, p in enumerate(prompts):
print( print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
)
outs = llm.generate([p], sp, use_tqdm=False) outs = llm.generate([p], sp, use_tqdm=False)
assert len(outs) == 1 assert len(outs) == 1
step_logprobs, token_ids = _extract_step_logprobs(outs[0]) step_logprobs, token_ids = _extract_step_logprobs(outs[0])
...@@ -699,18 +709,13 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): ...@@ -699,18 +709,13 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
print(f"{'=' * 80}\n") print(f"{'=' * 80}\n")
pytest.fail(fail_msg) pytest.fail(fail_msg)
finally:
# Restore original value
if old_value is None:
os.environ.pop("VLLM_BATCH_INVARIANT", None)
else:
os.environ["VLLM_BATCH_INVARIANT"] = old_value
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.forked @pytest.mark.forked
def test_decode_logprobs_match_prefill_logprobs(backend): def test_decode_logprobs_match_prefill_logprobs(
backend, monkeypatch: pytest.MonkeyPatch
):
""" """
Test that verifies decode logprobs match prefill logprobs. Test that verifies decode logprobs match prefill logprobs.
...@@ -724,8 +729,7 @@ def test_decode_logprobs_match_prefill_logprobs(backend): ...@@ -724,8 +729,7 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
This ensures that the logprobs from decode are consistent with what This ensures that the logprobs from decode are consistent with what
we would get if we ran prefill on each prefix. we would get if we ran prefill on each prefix.
""" """
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
os.environ["VLLM_ATTENTION_BACKEND"] = backend
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
......
...@@ -753,13 +753,13 @@ def override_envs_for_invariance(): ...@@ -753,13 +753,13 @@ def override_envs_for_invariance():
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
supported_backends = [ supported_backends = [
"FLASH_ATTN", # best supported backend "FLASH_ATTN", # best supported backend
"FLEX_ATTENTION",
"FLASHINFER", "FLASHINFER",
"FLASH_ATTN_MLA", "FLASH_ATTN_MLA",
"FLASHINFER_MLA", "FLASHINFER_MLA",
"TRITON_MLA", "TRITON_MLA",
# Not yet supported MLA backends # Not yet supported MLA backends
# "FLASHMLA", # "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
] ]
if curr_attn_backend not in supported_backends: if curr_attn_backend not in supported_backends:
warning = ( warning = (
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment