Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6afc28a9
Unverified
Commit
6afc28a9
authored
Oct 28, 2025
by
Wentao Ye
Committed by
GitHub
Oct 28, 2025
Browse files
[Test] Batch Invariant: Unit test using parameterized backend (#27478)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
141e6a05
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
208 additions
and
204 deletions
+208
-204
tests/v1/generation/test_batch_invariance.py
tests/v1/generation/test_batch_invariance.py
+207
-203
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+1
-1
No files found.
tests/v1/generation/test_batch_invariance.py
View file @
6afc28a9
...
...
@@ -17,16 +17,10 @@ skip_unsupported = pytest.mark.skipif(
@
pytest
.
fixture
(
autouse
=
True
)
def
enable_batch_invariant_mode
():
def
enable_batch_invariant_mode
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Automatically enable batch invariant kernel overrides for all tests."""
old_value
=
os
.
environ
.
get
(
"VLLM_BATCH_INVARIANT"
)
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
"1"
monkeypatch
.
setenv
(
"VLLM_BATCH_INVARIANT"
,
"1"
)
yield
# Restore original value after test
if
old_value
is
None
:
os
.
environ
.
pop
(
"VLLM_BATCH_INVARIANT"
,
None
)
else
:
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
old_value
def
_random_prompt
(
min_words
:
int
=
1024
,
max_words
:
int
=
1024
*
2
)
->
str
:
...
...
@@ -76,7 +70,13 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
@
skip_unsupported
@
pytest
.
mark
.
timeout
(
1000
)
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
():
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Ensures that the same request (the 'needle' prompt) yields identical output
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
...
...
@@ -101,6 +101,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
...
...
@@ -220,11 +221,15 @@ def _extract_step_logprobs(request_output):
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
@
pytest
.
mark
.
forked
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
):
backend
=
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
...
...
@@ -435,11 +440,16 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
@
skip_unsupported
def
test_simple_generation
():
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
def
test_simple_generation
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging.
"""
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
llm
=
LLM
(
...
...
@@ -481,9 +491,14 @@ def test_simple_generation():
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
@
pytest
.
mark
.
forked
def
test_logprobs_WITHOUT_batch_invariance_should_FAIL
(
backend
):
def
test_logprobs_without_batch_invariance_should_fail
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
It DISABLES batch invariance mode and expects to see non-deterministic behavior
...
...
@@ -493,14 +508,11 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
"""
backend
=
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
# CRITICAL: Disable batch invariance for this test
old_value
=
os
.
environ
.
get
(
"VLLM_BATCH_INVARIANT"
)
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
"0"
monkeypatch
.
setenv
(
"VLLM_BATCH_INVARIANT"
,
"0"
)
try
:
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
model_name
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
...
...
@@ -550,9 +562,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
bs1_logprobs_per_prompt
=
[]
bs1_tokens_per_prompt
=
[]
for
idx
,
p
in
enumerate
(
prompts
):
print
(
f
"
\n
[BS=1] Running prompt
{
idx
}
/
{
len
(
prompts
)
}
- Preview:
{
p
[:
80
]
}
..."
)
print
(
f
"
\n
[BS=1] Running prompt
{
idx
}
/
{
len
(
prompts
)
}
- Preview:
{
p
[:
80
]
}
..."
)
outs
=
llm
.
generate
([
p
],
sp
,
use_tqdm
=
False
)
assert
len
(
outs
)
==
1
step_logprobs
,
token_ids
=
_extract_step_logprobs
(
outs
[
0
])
...
...
@@ -699,18 +709,13 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
print
(
f
"
{
'='
*
80
}
\n
"
)
pytest
.
fail
(
fail_msg
)
finally
:
# Restore original value
if
old_value
is
None
:
os
.
environ
.
pop
(
"VLLM_BATCH_INVARIANT"
,
None
)
else
:
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
old_value
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
])
@
pytest
.
mark
.
forked
def
test_decode_logprobs_match_prefill_logprobs
(
backend
):
def
test_decode_logprobs_match_prefill_logprobs
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Test that verifies decode logprobs match prefill logprobs.
...
...
@@ -724,8 +729,7 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
This ensures that the logprobs from decode are consistent with what
we would get if we ran prefill on each prefix.
"""
backend
=
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
...
...
vllm/model_executor/layers/batch_invariant.py
View file @
6afc28a9
...
...
@@ -753,13 +753,13 @@ def override_envs_for_invariance():
curr_attn_backend
=
envs
.
VLLM_ATTENTION_BACKEND
supported_backends
=
[
"FLASH_ATTN"
,
# best supported backend
"FLEX_ATTENTION"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
,
# Not yet supported MLA backends
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
]
if
curr_attn_backend
not
in
supported_backends
:
warning
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment