Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1607e664
Unverified
Commit
1607e664
authored
Nov 19, 2025
by
Wentao Ye
Committed by
GitHub
Nov 19, 2025
Browse files
[Bug] Fix Batch Invariant MLA test (#28967)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
68d72319
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
10 deletions
+33
-10
tests/v1/determinism/test_batch_invariance.py
tests/v1/determinism/test_batch_invariance.py
+32
-9
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+1
-1
No files found.
tests/v1/determinism/test_batch_invariance.py
View file @
1607e664
...
@@ -9,13 +9,33 @@ import torch
...
@@ -9,13 +9,33 @@ import torch
from
utils
import
_extract_step_logprobs
,
_random_prompt
,
skip_unsupported
from
utils
import
_extract_step_logprobs
,
_random_prompt
,
skip_unsupported
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.platforms
import
current_platform
BACKENDS
:
list
[
str
]
=
[
"FLASH_ATTN"
,
"FLASHINFER"
,
]
if
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
90
):
BACKENDS
.
append
(
"FLASH_ATTN_MLA"
)
DEFAULT_MODEL
=
"Qwen/Qwen3-1.7B"
MLA_MODEL
=
"deepseek-ai/DeepSeek-V2-Lite-Chat"
def
resolve_model_name
(
backend
:
str
)
->
str
:
"""Resolve the model name for the given backend, respecting env overrides."""
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
DEFAULT_MODEL
)
if
backend
.
endswith
(
"MLA"
)
and
model
==
DEFAULT_MODEL
:
return
MLA_MODEL
return
model
@
skip_unsupported
@
skip_unsupported
@
pytest
.
mark
.
timeout
(
1000
)
@
pytest
.
mark
.
timeout
(
1000
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"backend"
,
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
]
,
BACKENDS
,
)
)
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
(
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
...
@@ -47,7 +67,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
...
@@ -47,7 +67,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
# Allow overrides from environment (useful for CI tuning)
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
model
=
resolve_model_name
(
backend
)
num_trials
=
int
(
os
.
getenv
(
"VLLM_NEEDLE_TRIALS"
,
"5"
))
num_trials
=
int
(
os
.
getenv
(
"VLLM_NEEDLE_TRIALS"
,
"5"
))
max_batch_size
=
int
(
os
.
getenv
(
"VLLM_NEEDLE_BATCH_SIZE"
,
"128"
))
max_batch_size
=
int
(
os
.
getenv
(
"VLLM_NEEDLE_BATCH_SIZE"
,
"128"
))
min_random_prompt
=
int
(
os
.
getenv
(
"VLLM_MIN_PROMPT"
,
"1024"
))
min_random_prompt
=
int
(
os
.
getenv
(
"VLLM_MIN_PROMPT"
,
"1024"
))
...
@@ -150,7 +170,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
...
@@ -150,7 +170,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
@
skip_unsupported
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"backend"
,
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
]
,
BACKENDS
,
)
)
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
...
@@ -160,7 +180,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
...
@@ -160,7 +180,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
model_name
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
model_name
=
resolve_model_name
(
backend
)
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
# For batch invariance, disable custom all-reduce to ensure deterministic
# For batch invariance, disable custom all-reduce to ensure deterministic
...
@@ -369,7 +389,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
...
@@ -369,7 +389,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@
skip_unsupported
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"backend"
,
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
]
,
BACKENDS
,
)
)
def
test_simple_generation
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_simple_generation
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
"""
...
@@ -377,7 +397,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
...
@@ -377,7 +397,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
Useful for quick smoke testing and debugging.
Useful for quick smoke testing and debugging.
"""
"""
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
model
=
resolve_model_name
(
backend
)
llm
=
LLM
(
llm
=
LLM
(
model
=
model
,
model
=
model
,
...
@@ -419,7 +439,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
...
@@ -419,7 +439,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
@
skip_unsupported
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"backend"
,
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
]
,
BACKENDS
,
)
)
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_logprobs_without_batch_invariance_should_fail
(
def
test_logprobs_without_batch_invariance_should_fail
(
...
@@ -434,6 +454,9 @@ def test_logprobs_without_batch_invariance_should_fail(
...
@@ -434,6 +454,9 @@ def test_logprobs_without_batch_invariance_should_fail(
The test will PASS if we detect differences (proving batch invariance matters).
The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
"""
"""
from
vllm.model_executor.layers.batch_invariant
import
vllm_is_batch_invariant
vllm_is_batch_invariant
.
cache_clear
()
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
# CRITICAL: Disable batch invariance for this test
# CRITICAL: Disable batch invariance for this test
...
@@ -441,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
...
@@ -441,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
model_name
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
model_name
=
resolve_model_name
(
backend
)
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
print
(
f
"
\n
{
'='
*
80
}
"
)
print
(
f
"
\n
{
'='
*
80
}
"
)
...
@@ -659,7 +682,7 @@ def test_decode_logprobs_match_prefill_logprobs(
...
@@ -659,7 +682,7 @@ def test_decode_logprobs_match_prefill_logprobs(
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
model_name
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
model_name
=
resolve_model_name
(
backend
)
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
from
vllm.model_executor.layers.batch_invariant
import
(
from
vllm.model_executor.layers.batch_invariant
import
(
...
...
vllm/model_executor/layers/batch_invariant.py
View file @
1607e664
...
@@ -803,11 +803,11 @@ def override_envs_for_invariance():
...
@@ -803,11 +803,11 @@ def override_envs_for_invariance():
"FLASH_ATTN"
,
# best supported backend
"FLASH_ATTN"
,
# best supported backend
"FLASHINFER"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
,
"TRITON_MLA"
,
# Not yet supported MLA backends
# Not yet supported MLA backends
# "FLASHMLA",
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
]
]
if
curr_attn_backend
not
in
supported_backends
:
if
curr_attn_backend
not
in
supported_backends
:
warning
=
(
warning
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment