Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eaf49786
Unverified
Commit
eaf49786
authored
Mar 22, 2026
by
Wentao Ye
Committed by
GitHub
Mar 22, 2026
Browse files
[Test] Only Run MLA model when user explicitly set for batch invariance (#37719)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
77d24c4b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
27 deletions
+23
-27
tests/v1/determinism/test_batch_invariance.py
tests/v1/determinism/test_batch_invariance.py
+6
-9
tests/v1/determinism/test_online_batch_invariance.py
tests/v1/determinism/test_online_batch_invariance.py
+3
-4
tests/v1/determinism/utils.py
tests/v1/determinism/utils.py
+14
-14
No files found.
tests/v1/determinism/test_batch_invariance.py
View file @
eaf49786
...
@@ -8,10 +8,10 @@ import pytest
...
@@ -8,10 +8,10 @@ import pytest
import
torch
import
torch
from
utils
import
(
from
utils
import
(
BACKENDS
,
BACKENDS
,
TEST_MODEL
,
_extract_step_logprobs
,
_extract_step_logprobs
,
_random_prompt
,
_random_prompt
,
is_device_capability_below_90
,
is_device_capability_below_90
,
resolve_model_name
,
skip_unsupported
,
skip_unsupported
,
)
)
...
@@ -57,7 +57,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
...
@@ -57,7 +57,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
attention_config
=
{
"backend"
:
backend
}
attention_config
=
{
"backend"
:
backend
}
# Allow overrides from environment (useful for CI tuning)
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model
=
resolve_model_name
(
backend
)
model
=
TEST_MODEL
num_trials
=
int
(
os
.
getenv
(
"VLLM_NEEDLE_TRIALS"
,
"5"
))
num_trials
=
int
(
os
.
getenv
(
"VLLM_NEEDLE_TRIALS"
,
"5"
))
max_batch_size
=
int
(
os
.
getenv
(
"VLLM_NEEDLE_BATCH_SIZE"
,
"128"
))
max_batch_size
=
int
(
os
.
getenv
(
"VLLM_NEEDLE_BATCH_SIZE"
,
"128"
))
min_random_prompt
=
int
(
os
.
getenv
(
"VLLM_MIN_PROMPT"
,
"1024"
))
min_random_prompt
=
int
(
os
.
getenv
(
"VLLM_MIN_PROMPT"
,
"1024"
))
...
@@ -169,7 +169,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
...
@@ -169,7 +169,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
):
):
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
model_name
=
resolve_model_name
(
backend
)
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
# For batch invariance, disable custom all-reduce to ensure deterministic
# For batch invariance, disable custom all-reduce to ensure deterministic
...
@@ -186,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
...
@@ -186,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
print
(
f
"
{
'='
*
80
}
\n
"
)
print
(
f
"
{
'='
*
80
}
\n
"
)
llm
=
LLM
(
llm
=
LLM
(
model
=
model_name
,
model
=
TEST_MODEL
,
tensor_parallel_size
=
tp_size
,
tensor_parallel_size
=
tp_size
,
max_num_seqs
=
128
,
max_num_seqs
=
128
,
max_model_len
=
8192
,
max_model_len
=
8192
,
...
@@ -395,7 +394,7 @@ def test_simple_generation(backend):
...
@@ -395,7 +394,7 @@ def test_simple_generation(backend):
Simple test that runs the model with a basic prompt and prints the output.
Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging.
Useful for quick smoke testing and debugging.
"""
"""
model
=
resolve_model_name
(
backend
)
model
=
TEST_MODEL
llm
=
LLM
(
llm
=
LLM
(
model
=
model
,
model
=
model
,
...
@@ -458,7 +457,6 @@ def test_logprobs_without_batch_invariance_should_fail(
...
@@ -458,7 +457,6 @@ def test_logprobs_without_batch_invariance_should_fail(
monkeypatch
.
setattr
(
batch_invariant
,
"VLLM_BATCH_INVARIANT"
,
False
)
monkeypatch
.
setattr
(
batch_invariant
,
"VLLM_BATCH_INVARIANT"
,
False
)
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
model_name
=
resolve_model_name
(
backend
)
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
print
(
f
"
\n
{
'='
*
80
}
"
)
print
(
f
"
\n
{
'='
*
80
}
"
)
...
@@ -466,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
...
@@ -466,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
print
(
f
"
{
'='
*
80
}
\n
"
)
print
(
f
"
{
'='
*
80
}
\n
"
)
llm
=
LLM
(
llm
=
LLM
(
model
=
model_name
,
model
=
TEST_MODEL
,
tensor_parallel_size
=
tp_size
,
tensor_parallel_size
=
tp_size
,
max_num_seqs
=
32
,
max_num_seqs
=
32
,
max_model_len
=
8192
,
max_model_len
=
8192
,
...
@@ -674,7 +672,6 @@ def test_decode_logprobs_match_prefill_logprobs(
...
@@ -674,7 +672,6 @@ def test_decode_logprobs_match_prefill_logprobs(
"""
"""
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
model_name
=
resolve_model_name
(
backend
)
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
from
vllm.model_executor.layers.batch_invariant
import
(
from
vllm.model_executor.layers.batch_invariant
import
(
...
@@ -689,7 +686,7 @@ def test_decode_logprobs_match_prefill_logprobs(
...
@@ -689,7 +686,7 @@ def test_decode_logprobs_match_prefill_logprobs(
print
(
f
"
{
'='
*
80
}
\n
"
)
print
(
f
"
{
'='
*
80
}
\n
"
)
llm
=
LLM
(
llm
=
LLM
(
model
=
model_name
,
model
=
TEST_MODEL
,
tensor_parallel_size
=
tp_size
,
tensor_parallel_size
=
tp_size
,
max_num_seqs
=
32
,
max_num_seqs
=
32
,
max_model_len
=
8192
,
max_model_len
=
8192
,
...
...
tests/v1/determinism/test_online_batch_invariance.py
View file @
eaf49786
...
@@ -17,7 +17,7 @@ from typing import Any
...
@@ -17,7 +17,7 @@ from typing import Any
import
openai
import
openai
import
pytest
import
pytest
from
utils
import
BACKENDS
,
_random_prompt
,
resolve_model_name
,
skip_unsupported
from
utils
import
BACKENDS
,
TEST_MODEL
,
_random_prompt
,
skip_unsupported
from
tests.utils
import
RemoteOpenAIServer
from
tests.utils
import
RemoteOpenAIServer
...
@@ -139,7 +139,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
...
@@ -139,7 +139,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend
:
str
,
backend
:
str
,
)
->
None
:
)
->
None
:
random
.
seed
(
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
)))
random
.
seed
(
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
)))
model_name
=
resolve_model_name
(
backend
)
prompts_all
=
[
_random_prompt
(
10
,
50
)
for
_
in
range
(
32
)]
prompts_all
=
[
_random_prompt
(
10
,
50
)
for
_
in
range
(
32
)]
sp_kwargs
:
dict
[
str
,
Any
]
=
{
sp_kwargs
:
dict
[
str
,
Any
]
=
{
...
@@ -159,11 +158,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
...
@@ -159,11 +158,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
if
tp_size
:
if
tp_size
:
server_args
+=
[
"-tp"
,
tp_size
]
server_args
+=
[
"-tp"
,
tp_size
]
with
RemoteOpenAIServer
(
model_name
,
server_args
)
as
server
:
with
RemoteOpenAIServer
(
TEST_MODEL
,
server_args
)
as
server
:
client
=
server
.
get_client
()
client
=
server
.
get_client
()
_compare_bs1_vs_bsn_single_process
(
_compare_bs1_vs_bsn_single_process
(
prompts
=
prompts_all
,
prompts
=
prompts_all
,
sp_kwargs
=
sp_kwargs
,
sp_kwargs
=
sp_kwargs
,
client
=
client
,
client
=
client
,
model_name
=
model_name
,
model_name
=
TEST_MODEL
,
)
)
tests/v1/determinism/utils.py
View file @
eaf49786
...
@@ -7,6 +7,10 @@ import pytest
...
@@ -7,6 +7,10 @@ import pytest
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_config
from
vllm.transformers_utils.model_arch_config_convertor
import
(
ModelArchConfigConvertorBase
,
)
from
vllm.v1.attention.backends.fa_utils
import
flash_attn_supports_mla
from
vllm.v1.attention.backends.fa_utils
import
flash_attn_supports_mla
skip_unsupported
=
pytest
.
mark
.
skipif
(
skip_unsupported
=
pytest
.
mark
.
skipif
(
...
@@ -16,10 +20,12 @@ skip_unsupported = pytest.mark.skipif(
...
@@ -16,10 +20,12 @@ skip_unsupported = pytest.mark.skipif(
reason
=
"Requires CUDA and >= Ampere (SM80)"
,
reason
=
"Requires CUDA and >= Ampere (SM80)"
,
)
)
DEFAULT_MODEL
=
"Qwen/Qwen3-1.7B"
TEST_MODEL
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
DEFAULT_MODEL
)
BACKENDS
:
list
[
str
]
=
[
BACKENDS
:
list
[
str
]
=
[
"FLASH_ATTN"
,
"FLASH_ATTN"
,
"TRITON_ATTN"
,
"TRITON_ATTN"
,
"TRITON_MLA"
,
]
]
# FlashInfer temporarily disabled due to invariant CTA sizes.
# FlashInfer temporarily disabled due to invariant CTA sizes.
...
@@ -27,20 +33,14 @@ BACKENDS: list[str] = [
...
@@ -27,20 +33,14 @@ BACKENDS: list[str] = [
# if has_flashinfer():
# if has_flashinfer():
# BACKENDS.append("FLASHINFER")
# BACKENDS.append("FLASHINFER")
if
flash_attn_supports_mla
():
# only run MLA backends when the requested test model is itself an MLA model.
if
os
.
getenv
(
"VLLM_TEST_MODEL"
):
config
=
get_config
(
TEST_MODEL
,
trust_remote_code
=
False
)
if
ModelArchConfigConvertorBase
(
config
,
config
.
get_text_config
()).
is_deepseek_mla
():
BACKENDS
=
[
"TRITON_MLA"
]
if
flash_attn_supports_mla
():
BACKENDS
.
append
(
"FLASH_ATTN_MLA"
)
BACKENDS
.
append
(
"FLASH_ATTN_MLA"
)
DEFAULT_MODEL
=
"Qwen/Qwen3-1.7B"
MLA_MODEL
=
"deepseek-ai/DeepSeek-V2-Lite-Chat"
def
resolve_model_name
(
backend
:
str
)
->
str
:
"""Resolve the model name for the given backend."""
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
DEFAULT_MODEL
)
if
backend
.
endswith
(
"MLA"
)
and
model
==
DEFAULT_MODEL
:
return
MLA_MODEL
return
model
def
_random_prompt
(
min_words
:
int
=
1024
,
max_words
:
int
=
1024
*
2
)
->
str
:
def
_random_prompt
(
min_words
:
int
=
1024
,
max_words
:
int
=
1024
*
2
)
->
str
:
# Generate more realistic prompts that will actually produce varied tokens
# Generate more realistic prompts that will actually produce varied tokens
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment