Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6afc28a9
Unverified
Commit
6afc28a9
authored
Oct 28, 2025
by
Wentao Ye
Committed by
GitHub
Oct 28, 2025
Browse files
[Test] Batch Invariant: Unit test using parameterized backend (#27478)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
141e6a05
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
208 additions
and
204 deletions
+208
-204
tests/v1/generation/test_batch_invariance.py
tests/v1/generation/test_batch_invariance.py
+207
-203
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+1
-1
No files found.
tests/v1/generation/test_batch_invariance.py
View file @
6afc28a9
...
...
@@ -17,16 +17,10 @@ skip_unsupported = pytest.mark.skipif(
@
pytest
.
fixture
(
autouse
=
True
)
def
enable_batch_invariant_mode
():
def
enable_batch_invariant_mode
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Automatically enable batch invariant kernel overrides for all tests."""
old_value
=
os
.
environ
.
get
(
"VLLM_BATCH_INVARIANT"
)
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
"1"
monkeypatch
.
setenv
(
"VLLM_BATCH_INVARIANT"
,
"1"
)
yield
# Restore original value after test
if
old_value
is
None
:
os
.
environ
.
pop
(
"VLLM_BATCH_INVARIANT"
,
None
)
else
:
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
old_value
def
_random_prompt
(
min_words
:
int
=
1024
,
max_words
:
int
=
1024
*
2
)
->
str
:
...
...
@@ -76,7 +70,13 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
@
skip_unsupported
@
pytest
.
mark
.
timeout
(
1000
)
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
():
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Ensures that the same request (the 'needle' prompt) yields identical output
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
...
...
@@ -101,6 +101,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
...
...
@@ -220,11 +221,15 @@ def _extract_step_logprobs(request_output):
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
@
pytest
.
mark
.
forked
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
):
backend
=
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
...
...
@@ -435,11 +440,16 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
@
skip_unsupported
def
test_simple_generation
():
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
def
test_simple_generation
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging.
"""
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
llm
=
LLM
(
...
...
@@ -481,9 +491,14 @@ def test_simple_generation():
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
@
pytest
.
mark
.
forked
def
test_logprobs_WITHOUT_batch_invariance_should_FAIL
(
backend
):
def
test_logprobs_without_batch_invariance_should_fail
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
It DISABLES batch invariance mode and expects to see non-deterministic behavior
...
...
@@ -493,224 +508,214 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
"""
backend
=
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
# CRITICAL: Disable batch invariance for this test
old_value
=
os
.
environ
.
get
(
"VLLM_BATCH_INVARIANT"
)
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
"0"
monkeypatch
.
setenv
(
"VLLM_BATCH_INVARIANT"
,
"0"
)
try
:
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
model_name
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
model_name
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
print
(
f
"
\n
{
'='
*
80
}
"
)
print
(
"BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior"
)
print
(
f
"
{
'='
*
80
}
\n
"
)
print
(
f
"
\n
{
'='
*
80
}
"
)
print
(
"BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior"
)
print
(
f
"
{
'='
*
80
}
\n
"
)
llm
=
LLM
(
model
=
model_name
,
tensor_parallel_size
=
tp_size
,
enable_prefix_caching
=
False
,
max_num_seqs
=
32
,
max_model_len
=
8192
,
dtype
=
"bfloat16"
,
)
llm
=
LLM
(
model
=
model_name
,
tensor_parallel_size
=
tp_size
,
enable_prefix_caching
=
False
,
max_num_seqs
=
32
,
max_model_len
=
8192
,
dtype
=
"bfloat16"
,
)
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
long_min
=
int
(
os
.
getenv
(
"VLLM_MIN_PROMPT"
,
"768"
))
long_max
=
int
(
os
.
getenv
(
"VLLM_MAX_PROMPT"
,
"2048"
))
prompts
:
list
[
str
]
=
[]
options
=
[
(
max
(
long_min
,
1536
),
max
(
long_max
,
3072
)),
# very long
(
max
(
1024
,
long_min
),
max
(
2048
,
long_max
)),
# long
(
256
,
512
),
# mid
(
10
,
20
),
# short
]
for
_
in
range
(
32
):
lo
,
hi
=
random
.
choice
(
options
)
prompts
.
append
(
_random_prompt
(
lo
,
hi
))
sp
=
SamplingParams
(
temperature
=
0.6
,
top_p
=
1.0
,
max_tokens
=
8
,
seed
=
1234
,
logprobs
=
5
,
)
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
long_min
=
int
(
os
.
getenv
(
"VLLM_MIN_PROMPT"
,
"768"
))
long_max
=
int
(
os
.
getenv
(
"VLLM_MAX_PROMPT"
,
"2048"
))
prompts
:
list
[
str
]
=
[]
options
=
[
(
max
(
long_min
,
1536
),
max
(
long_max
,
3072
)),
# very long
(
max
(
1024
,
long_min
),
max
(
2048
,
long_max
)),
# long
(
256
,
512
),
# mid
(
10
,
20
),
# short
]
# BS=1: run prompts individually and collect logprobs per step.
print
(
"
\n
"
+
"="
*
80
)
print
(
"STARTING BS=1 RUNS (each prompt individually)"
)
print
(
"="
*
80
+
"
\n
"
)
for
_
in
range
(
32
):
lo
,
hi
=
random
.
choice
(
options
)
prompts
.
append
(
_random_prompt
(
lo
,
hi
))
bs1_logprobs_per_prompt
=
[]
bs1_tokens_per_prompt
=
[]
for
idx
,
p
in
enumerate
(
prompts
):
print
(
f
"
\n
[BS=1] Running prompt
{
idx
}
/
{
len
(
prompts
)
}
- Preview:
{
p
[:
80
]
}
..."
sp
=
SamplingParams
(
temperature
=
0.6
,
top_p
=
1.0
,
max_tokens
=
8
,
seed
=
1234
,
logprobs
=
5
,
)
# BS=1: run prompts individually and collect logprobs per step.
print
(
"
\n
"
+
"="
*
80
)
print
(
"STARTING BS=1 RUNS (each prompt individually)"
)
print
(
"="
*
80
+
"
\n
"
)
bs1_logprobs_per_prompt
=
[]
bs1_tokens_per_prompt
=
[]
for
idx
,
p
in
enumerate
(
prompts
):
print
(
f
"
\n
[BS=1] Running prompt
{
idx
}
/
{
len
(
prompts
)
}
- Preview:
{
p
[:
80
]
}
..."
)
outs
=
llm
.
generate
([
p
],
sp
,
use_tqdm
=
False
)
assert
len
(
outs
)
==
1
step_logprobs
,
token_ids
=
_extract_step_logprobs
(
outs
[
0
])
if
step_logprobs
is
None
:
pytest
.
skip
(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
outs
=
llm
.
generate
([
p
],
sp
,
use_tqdm
=
False
)
assert
len
(
outs
)
==
1
step_logprobs
,
token_ids
=
_extract_step_logprobs
(
outs
[
0
])
if
step_logprobs
is
None
:
pytest
.
skip
(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
bs1_logprobs_per_prompt
.
append
(
step_logprobs
)
bs1_tokens_per_prompt
.
append
(
token_ids
)
print
(
f
"[BS=1] Prompt
{
idx
}
generated tokens:
{
token_ids
}
"
)
# BS=N: run prompts in a batch and collect logprobs per step for each prompt.
print
(
"
\n
"
+
"="
*
80
)
print
(
f
"STARTING BS=
{
len
(
prompts
)
}
RUN (all prompts batched)"
)
print
(
"="
*
80
+
"
\n
"
)
outs_batched
=
llm
.
generate
(
prompts
,
sp
,
use_tqdm
=
False
)
assert
len
(
outs_batched
)
==
len
(
prompts
)
bsN_logprobs_per_prompt
=
[]
bsN_tokens_per_prompt
=
[]
print
(
f
"
\n
[BS=
{
len
(
prompts
)
}
] Processing batched outputs..."
)
for
idx
,
o
in
enumerate
(
outs_batched
):
tokens
=
o
.
outputs
[
0
].
token_ids
if
o
.
outputs
else
"N/A"
print
(
f
"[BS=
{
len
(
prompts
)
}
] Prompt
{
idx
}
generated tokens:
{
tokens
}
"
)
step_logprobs
,
token_ids
=
_extract_step_logprobs
(
o
)
if
step_logprobs
is
None
:
pytest
.
skip
(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
bsN_logprobs_per_prompt
.
append
(
step_logprobs
)
bsN_tokens_per_prompt
.
append
(
token_ids
)
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
differences_found
=
[]
for
i
,
(
logprobs_bs1
,
logprobs_bsN
,
tokens_bs1
,
tokens_bsN
)
in
enumerate
(
zip
(
bs1_logprobs_per_prompt
,
bsN_logprobs_per_prompt
,
bs1_tokens_per_prompt
,
bsN_tokens_per_prompt
,
bs1_logprobs_per_prompt
.
append
(
step_logprobs
)
bs1_tokens_per_prompt
.
append
(
token_ids
)
print
(
f
"[BS=1] Prompt
{
idx
}
generated tokens:
{
token_ids
}
"
)
# BS=N: run prompts in a batch and collect logprobs per step for each prompt.
print
(
"
\n
"
+
"="
*
80
)
print
(
f
"STARTING BS=
{
len
(
prompts
)
}
RUN (all prompts batched)"
)
print
(
"="
*
80
+
"
\n
"
)
outs_batched
=
llm
.
generate
(
prompts
,
sp
,
use_tqdm
=
False
)
assert
len
(
outs_batched
)
==
len
(
prompts
)
bsN_logprobs_per_prompt
=
[]
bsN_tokens_per_prompt
=
[]
print
(
f
"
\n
[BS=
{
len
(
prompts
)
}
] Processing batched outputs..."
)
for
idx
,
o
in
enumerate
(
outs_batched
):
tokens
=
o
.
outputs
[
0
].
token_ids
if
o
.
outputs
else
"N/A"
print
(
f
"[BS=
{
len
(
prompts
)
}
] Prompt
{
idx
}
generated tokens:
{
tokens
}
"
)
step_logprobs
,
token_ids
=
_extract_step_logprobs
(
o
)
if
step_logprobs
is
None
:
pytest
.
skip
(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
):
if
len
(
logprobs_bs1
)
!=
len
(
logprobs_bsN
):
reason
=
(
f
"Different number of steps:
{
len
(
logprobs_bs1
)
}
(BS=1) "
f
"vs
{
len
(
logprobs_bsN
)
}
(BS=N)"
)
bsN_logprobs_per_prompt
.
append
(
step_logprobs
)
bsN_tokens_per_prompt
.
append
(
token_ids
)
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
differences_found
=
[]
for
i
,
(
logprobs_bs1
,
logprobs_bsN
,
tokens_bs1
,
tokens_bsN
)
in
enumerate
(
zip
(
bs1_logprobs_per_prompt
,
bsN_logprobs_per_prompt
,
bs1_tokens_per_prompt
,
bsN_tokens_per_prompt
,
)
):
if
len
(
logprobs_bs1
)
!=
len
(
logprobs_bsN
):
reason
=
(
f
"Different number of steps:
{
len
(
logprobs_bs1
)
}
(BS=1) "
f
"vs
{
len
(
logprobs_bsN
)
}
(BS=N)"
)
differences_found
.
append
(
{
"prompt_idx"
:
i
,
"step"
:
"all"
,
"reason"
:
reason
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
}
)
continue
# Check if tokens match first
if
tokens_bs1
!=
tokens_bsN
:
differences_found
.
append
(
{
"prompt_idx"
:
i
,
"step"
:
"sampling"
,
"reason"
:
"Different tokens sampled"
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
}
)
continue
for
t
,
(
a
,
b
)
in
enumerate
(
zip
(
logprobs_bs1
,
logprobs_bsN
)):
if
a
.
shape
!=
b
.
shape
:
differences_found
.
append
(
{
"prompt_idx"
:
i
,
"step"
:
"all"
,
"reason"
:
reason
,
"step"
:
t
,
"reason"
:
f
"Shape mismatch:
{
a
.
shape
}
vs
{
b
.
shape
}
"
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
}
)
continue
break
# Check if tokens match first
if
tokens_bs1
!=
tokens_bsN
:
if
not
torch
.
equal
(
a
,
b
):
max_diff
=
torch
.
abs
(
a
-
b
).
max
().
item
()
print
(
f
"
\n
[EXPECTED DIVERGENCE FOUND] Prompt
{
i
}
, "
f
"Token
{
t
}
: max_diff=
{
max_diff
:.
6
e
}
"
)
bs1_tok
=
tokens_bs1
[
t
]
if
t
<
len
(
tokens_bs1
)
else
"N/A"
bsN_tok
=
tokens_bsN
[
t
]
if
t
<
len
(
tokens_bsN
)
else
"N/A"
print
(
f
" Token IDs: bs1=
{
bs1_tok
}
, bsN=
{
bsN_tok
}
"
)
print
(
f
" BS=1 logprob:
{
a
.
tolist
()
}
"
)
print
(
f
" BS=N logprob:
{
b
.
tolist
()
}
"
)
differences_found
.
append
(
{
"prompt_idx"
:
i
,
"step"
:
"sampling"
,
"reason"
:
"Different tokens sampled
"
,
"step"
:
t
,
"reason"
:
f
"Bitwise mismatch (max_diff=
{
max_diff
:.
6
e
}
)
"
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
}
)
continue
for
t
,
(
a
,
b
)
in
enumerate
(
zip
(
logprobs_bs1
,
logprobs_bsN
)):
if
a
.
shape
!=
b
.
shape
:
differences_found
.
append
(
{
"prompt_idx"
:
i
,
"step"
:
t
,
"reason"
:
f
"Shape mismatch:
{
a
.
shape
}
vs
{
b
.
shape
}
"
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
}
)
break
if
not
torch
.
equal
(
a
,
b
):
max_diff
=
torch
.
abs
(
a
-
b
).
max
().
item
()
print
(
f
"
\n
[EXPECTED DIVERGENCE FOUND] Prompt
{
i
}
, "
f
"Token
{
t
}
: max_diff=
{
max_diff
:.
6
e
}
"
)
bs1_tok
=
tokens_bs1
[
t
]
if
t
<
len
(
tokens_bs1
)
else
"N/A"
bsN_tok
=
tokens_bsN
[
t
]
if
t
<
len
(
tokens_bsN
)
else
"N/A"
print
(
f
" Token IDs: bs1=
{
bs1_tok
}
, bsN=
{
bsN_tok
}
"
)
print
(
f
" BS=1 logprob:
{
a
.
tolist
()
}
"
)
print
(
f
" BS=N logprob:
{
b
.
tolist
()
}
"
)
differences_found
.
append
(
{
"prompt_idx"
:
i
,
"step"
:
t
,
"reason"
:
f
"Bitwise mismatch (max_diff=
{
max_diff
:.
6
e
}
)"
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
}
)
break
# Print summary
print
(
f
"
\n
{
'='
*
80
}
"
)
if
differences_found
:
success_msg
=
(
f
"✓ SUCCESS: Batch invariance is doing something! "
f
"Found
{
len
(
differences_found
)
}
/
{
len
(
prompts
)
}
prompts "
f
"with differences when batch invariance was DISABLED."
)
print
(
success_msg
)
print
(
f
"
{
'='
*
80
}
"
)
for
diff
in
differences_found
:
print
(
f
"
\n
Prompt
{
diff
[
'prompt_idx'
]
}
(step
{
diff
[
'step'
]
}
):"
)
print
(
f
" Reason:
{
diff
[
'reason'
]
}
"
)
print
(
f
" Preview:
{
diff
[
'prompt_preview'
]
}
..."
)
if
"bs1_tokens"
in
diff
:
print
(
f
" BS=1 tokens:
{
diff
[
'bs1_tokens'
]
}
"
)
if
"bsN_tokens"
in
diff
:
print
(
f
" BS=N tokens:
{
diff
[
'bsN_tokens'
]
}
"
)
print
(
f
"
{
'='
*
80
}
\n
"
)
# Test PASSES because we found differences (batch invariance matters!)
return
else
:
# Test FAILS because everything matched even without batch invariance
fail_msg
=
(
f
"✗ UNEXPECTED: All
{
len
(
prompts
)
}
prompts matched "
f
"between BS=1 and BS=N even with batch invariance DISABLED. "
f
"This suggests batch invariance might not be necessary, "
f
"or the test needs more sensitive prompts."
)
print
(
fail_msg
)
print
(
f
"
{
'='
*
80
}
\n
"
)
pytest
.
fail
(
fail_msg
)
break
finally
:
# Restore original value
if
old_value
is
None
:
os
.
environ
.
pop
(
"VLLM_BATCH_INVARIANT"
,
None
)
else
:
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
old_value
# Print summary
print
(
f
"
\n
{
'='
*
80
}
"
)
if
differences_found
:
success_msg
=
(
f
"✓ SUCCESS: Batch invariance is doing something! "
f
"Found
{
len
(
differences_found
)
}
/
{
len
(
prompts
)
}
prompts "
f
"with differences when batch invariance was DISABLED."
)
print
(
success_msg
)
print
(
f
"
{
'='
*
80
}
"
)
for
diff
in
differences_found
:
print
(
f
"
\n
Prompt
{
diff
[
'prompt_idx'
]
}
(step
{
diff
[
'step'
]
}
):"
)
print
(
f
" Reason:
{
diff
[
'reason'
]
}
"
)
print
(
f
" Preview:
{
diff
[
'prompt_preview'
]
}
..."
)
if
"bs1_tokens"
in
diff
:
print
(
f
" BS=1 tokens:
{
diff
[
'bs1_tokens'
]
}
"
)
if
"bsN_tokens"
in
diff
:
print
(
f
" BS=N tokens:
{
diff
[
'bsN_tokens'
]
}
"
)
print
(
f
"
{
'='
*
80
}
\n
"
)
# Test PASSES because we found differences (batch invariance matters!)
return
else
:
# Test FAILS because everything matched even without batch invariance
fail_msg
=
(
f
"✗ UNEXPECTED: All
{
len
(
prompts
)
}
prompts matched "
f
"between BS=1 and BS=N even with batch invariance DISABLED. "
f
"This suggests batch invariance might not be necessary, "
f
"or the test needs more sensitive prompts."
)
print
(
fail_msg
)
print
(
f
"
{
'='
*
80
}
\n
"
)
pytest
.
fail
(
fail_msg
)
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
])
@
pytest
.
mark
.
forked
def
test_decode_logprobs_match_prefill_logprobs
(
backend
):
def
test_decode_logprobs_match_prefill_logprobs
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Test that verifies decode logprobs match prefill logprobs.
...
...
@@ -724,8 +729,7 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
This ensures that the logprobs from decode are consistent with what
we would get if we ran prefill on each prefix.
"""
backend
=
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
...
...
vllm/model_executor/layers/batch_invariant.py
View file @
6afc28a9
...
...
@@ -753,13 +753,13 @@ def override_envs_for_invariance():
curr_attn_backend
=
envs
.
VLLM_ATTENTION_BACKEND
supported_backends
=
[
"FLASH_ATTN"
,
# best supported backend
"FLEX_ATTENTION"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
,
# Not yet supported MLA backends
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
]
if
curr_attn_backend
not
in
supported_backends
:
warning
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment