Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6afc28a9
Unverified
Commit
6afc28a9
authored
Oct 28, 2025
by
Wentao Ye
Committed by
GitHub
Oct 28, 2025
Browse files
[Test] Batch Invariant: Unit test using parameterized backend (#27478)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
141e6a05
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
208 additions
and
204 deletions
+208
-204
tests/v1/generation/test_batch_invariance.py
tests/v1/generation/test_batch_invariance.py
+207
-203
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+1
-1
No files found.
tests/v1/generation/test_batch_invariance.py
View file @
6afc28a9
...
@@ -17,16 +17,10 @@ skip_unsupported = pytest.mark.skipif(
...
@@ -17,16 +17,10 @@ skip_unsupported = pytest.mark.skipif(
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
def
enable_batch_invariant_mode
():
def
enable_batch_invariant_mode
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Automatically enable batch invariant kernel overrides for all tests."""
"""Automatically enable batch invariant kernel overrides for all tests."""
old_value
=
os
.
environ
.
get
(
"VLLM_BATCH_INVARIANT"
)
monkeypatch
.
setenv
(
"VLLM_BATCH_INVARIANT"
,
"1"
)
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
"1"
yield
yield
# Restore original value after test
if
old_value
is
None
:
os
.
environ
.
pop
(
"VLLM_BATCH_INVARIANT"
,
None
)
else
:
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
old_value
def
_random_prompt
(
min_words
:
int
=
1024
,
max_words
:
int
=
1024
*
2
)
->
str
:
def
_random_prompt
(
min_words
:
int
=
1024
,
max_words
:
int
=
1024
*
2
)
->
str
:
...
@@ -76,7 +70,13 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
...
@@ -76,7 +70,13 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
@
skip_unsupported
@
skip_unsupported
@
pytest
.
mark
.
timeout
(
1000
)
@
pytest
.
mark
.
timeout
(
1000
)
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
():
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
"""
Ensures that the same request (the 'needle' prompt) yields identical output
Ensures that the same request (the 'needle' prompt) yields identical output
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
...
@@ -101,6 +101,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
...
@@ -101,6 +101,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
# Allow overrides from environment (useful for CI tuning)
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
...
@@ -220,11 +221,15 @@ def _extract_step_logprobs(request_output):
...
@@ -220,11 +221,15 @@ def _extract_step_logprobs(request_output):
@
skip_unsupported
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
):
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
=
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
):
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
...
@@ -435,11 +440,16 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
...
@@ -435,11 +440,16 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
@
skip_unsupported
@
skip_unsupported
def
test_simple_generation
():
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
def
test_simple_generation
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
"""
Simple test that runs the model with a basic prompt and prints the output.
Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging.
Useful for quick smoke testing and debugging.
"""
"""
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
llm
=
LLM
(
llm
=
LLM
(
...
@@ -481,9 +491,14 @@ def test_simple_generation():
...
@@ -481,9 +491,14 @@ def test_simple_generation():
@
skip_unsupported
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
],
)
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_logprobs_WITHOUT_batch_invariance_should_FAIL
(
backend
):
def
test_logprobs_without_batch_invariance_should_fail
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
"""
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
It DISABLES batch invariance mode and expects to see non-deterministic behavior
It DISABLES batch invariance mode and expects to see non-deterministic behavior
...
@@ -493,224 +508,214 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
...
@@ -493,224 +508,214 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
The test will PASS if we detect differences (proving batch invariance matters).
The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
"""
"""
backend
=
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
# CRITICAL: Disable batch invariance for this test
# CRITICAL: Disable batch invariance for this test
old_value
=
os
.
environ
.
get
(
"VLLM_BATCH_INVARIANT"
)
monkeypatch
.
setenv
(
"VLLM_BATCH_INVARIANT"
,
"0"
)
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
"0"
try
:
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
model_name
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
model_name
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
print
(
f
"
\n
{
'='
*
80
}
"
)
print
(
f
"
\n
{
'='
*
80
}
"
)
print
(
"BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior"
)
print
(
"BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior"
)
print
(
f
"
{
'='
*
80
}
\n
"
)
print
(
f
"
{
'='
*
80
}
\n
"
)
llm
=
LLM
(
llm
=
LLM
(
model
=
model_name
,
model
=
model_name
,
tensor_parallel_size
=
tp_size
,
tensor_parallel_size
=
tp_size
,
enable_prefix_caching
=
False
,
enable_prefix_caching
=
False
,
max_num_seqs
=
32
,
max_num_seqs
=
32
,
max_model_len
=
8192
,
max_model_len
=
8192
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
)
)
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
long_min
=
int
(
os
.
getenv
(
"VLLM_MIN_PROMPT"
,
"768"
))
long_min
=
int
(
os
.
getenv
(
"VLLM_MIN_PROMPT"
,
"768"
))
long_max
=
int
(
os
.
getenv
(
"VLLM_MAX_PROMPT"
,
"2048"
))
long_max
=
int
(
os
.
getenv
(
"VLLM_MAX_PROMPT"
,
"2048"
))
prompts
:
list
[
str
]
=
[]
prompts
:
list
[
str
]
=
[]
options
=
[
options
=
[
(
max
(
long_min
,
1536
),
max
(
long_max
,
3072
)),
# very long
(
max
(
long_min
,
1536
),
max
(
long_max
,
3072
)),
# very long
(
max
(
1024
,
long_min
),
max
(
2048
,
long_max
)),
# long
(
max
(
1024
,
long_min
),
max
(
2048
,
long_max
)),
# long
(
256
,
512
),
# mid
(
256
,
512
),
# mid
(
10
,
20
),
# short
(
10
,
20
),
# short
]
]
for
_
in
range
(
32
):
lo
,
hi
=
random
.
choice
(
options
)
prompts
.
append
(
_random_prompt
(
lo
,
hi
))
sp
=
SamplingParams
(
temperature
=
0.6
,
top_p
=
1.0
,
max_tokens
=
8
,
seed
=
1234
,
logprobs
=
5
,
)
# BS=1: run prompts individually and collect logprobs per step.
for
_
in
range
(
32
):
print
(
"
\n
"
+
"="
*
80
)
lo
,
hi
=
random
.
choice
(
options
)
print
(
"STARTING BS=1 RUNS (each prompt individually)"
)
prompts
.
append
(
_random_prompt
(
lo
,
hi
))
print
(
"="
*
80
+
"
\n
"
)
bs1_logprobs_per_prompt
=
[]
sp
=
SamplingParams
(
bs1_tokens_per_prompt
=
[]
temperature
=
0.6
,
for
idx
,
p
in
enumerate
(
prompts
):
top_p
=
1.0
,
print
(
max_tokens
=
8
,
f
"
\n
[BS=1] Running prompt
{
idx
}
/
{
len
(
prompts
)
}
- Preview:
{
p
[:
80
]
}
..."
seed
=
1234
,
logprobs
=
5
,
)
# BS=1: run prompts individually and collect logprobs per step.
print
(
"
\n
"
+
"="
*
80
)
print
(
"STARTING BS=1 RUNS (each prompt individually)"
)
print
(
"="
*
80
+
"
\n
"
)
bs1_logprobs_per_prompt
=
[]
bs1_tokens_per_prompt
=
[]
for
idx
,
p
in
enumerate
(
prompts
):
print
(
f
"
\n
[BS=1] Running prompt
{
idx
}
/
{
len
(
prompts
)
}
- Preview:
{
p
[:
80
]
}
..."
)
outs
=
llm
.
generate
([
p
],
sp
,
use_tqdm
=
False
)
assert
len
(
outs
)
==
1
step_logprobs
,
token_ids
=
_extract_step_logprobs
(
outs
[
0
])
if
step_logprobs
is
None
:
pytest
.
skip
(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
)
outs
=
llm
.
generate
([
p
],
sp
,
use_tqdm
=
False
)
bs1_logprobs_per_prompt
.
append
(
step_logprobs
)
assert
len
(
outs
)
==
1
bs1_tokens_per_prompt
.
append
(
token_ids
)
step_logprobs
,
token_ids
=
_extract_step_logprobs
(
outs
[
0
])
print
(
f
"[BS=1] Prompt
{
idx
}
generated tokens:
{
token_ids
}
"
)
if
step_logprobs
is
None
:
pytest
.
skip
(
# BS=N: run prompts in a batch and collect logprobs per step for each prompt.
"Logits are not available on RequestOutput; "
print
(
"
\n
"
+
"="
*
80
)
"enable logprobs return to run this test."
print
(
f
"STARTING BS=
{
len
(
prompts
)
}
RUN (all prompts batched)"
)
)
print
(
"="
*
80
+
"
\n
"
)
bs1_logprobs_per_prompt
.
append
(
step_logprobs
)
bs1_tokens_per_prompt
.
append
(
token_ids
)
outs_batched
=
llm
.
generate
(
prompts
,
sp
,
use_tqdm
=
False
)
print
(
f
"[BS=1] Prompt
{
idx
}
generated tokens:
{
token_ids
}
"
)
assert
len
(
outs_batched
)
==
len
(
prompts
)
bsN_logprobs_per_prompt
=
[]
# BS=N: run prompts in a batch and collect logprobs per step for each prompt.
bsN_tokens_per_prompt
=
[]
print
(
"
\n
"
+
"="
*
80
)
print
(
f
"STARTING BS=
{
len
(
prompts
)
}
RUN (all prompts batched)"
)
print
(
f
"
\n
[BS=
{
len
(
prompts
)
}
] Processing batched outputs..."
)
print
(
"="
*
80
+
"
\n
"
)
for
idx
,
o
in
enumerate
(
outs_batched
):
tokens
=
o
.
outputs
[
0
].
token_ids
if
o
.
outputs
else
"N/A"
outs_batched
=
llm
.
generate
(
prompts
,
sp
,
use_tqdm
=
False
)
print
(
f
"[BS=
{
len
(
prompts
)
}
] Prompt
{
idx
}
generated tokens:
{
tokens
}
"
)
assert
len
(
outs_batched
)
==
len
(
prompts
)
step_logprobs
,
token_ids
=
_extract_step_logprobs
(
o
)
bsN_logprobs_per_prompt
=
[]
if
step_logprobs
is
None
:
bsN_tokens_per_prompt
=
[]
pytest
.
skip
(
"Logits are not available on RequestOutput; "
print
(
f
"
\n
[BS=
{
len
(
prompts
)
}
] Processing batched outputs..."
)
"enable logprobs return to run this test."
for
idx
,
o
in
enumerate
(
outs_batched
):
tokens
=
o
.
outputs
[
0
].
token_ids
if
o
.
outputs
else
"N/A"
print
(
f
"[BS=
{
len
(
prompts
)
}
] Prompt
{
idx
}
generated tokens:
{
tokens
}
"
)
step_logprobs
,
token_ids
=
_extract_step_logprobs
(
o
)
if
step_logprobs
is
None
:
pytest
.
skip
(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
bsN_logprobs_per_prompt
.
append
(
step_logprobs
)
bsN_tokens_per_prompt
.
append
(
token_ids
)
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
differences_found
=
[]
for
i
,
(
logprobs_bs1
,
logprobs_bsN
,
tokens_bs1
,
tokens_bsN
)
in
enumerate
(
zip
(
bs1_logprobs_per_prompt
,
bsN_logprobs_per_prompt
,
bs1_tokens_per_prompt
,
bsN_tokens_per_prompt
,
)
)
):
bsN_logprobs_per_prompt
.
append
(
step_logprobs
)
if
len
(
logprobs_bs1
)
!=
len
(
logprobs_bsN
):
bsN_tokens_per_prompt
.
append
(
token_ids
)
reason
=
(
f
"Different number of steps:
{
len
(
logprobs_bs1
)
}
(BS=1) "
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
f
"vs
{
len
(
logprobs_bsN
)
}
(BS=N)"
differences_found
=
[]
)
for
i
,
(
logprobs_bs1
,
logprobs_bsN
,
tokens_bs1
,
tokens_bsN
)
in
enumerate
(
zip
(
bs1_logprobs_per_prompt
,
bsN_logprobs_per_prompt
,
bs1_tokens_per_prompt
,
bsN_tokens_per_prompt
,
)
):
if
len
(
logprobs_bs1
)
!=
len
(
logprobs_bsN
):
reason
=
(
f
"Different number of steps:
{
len
(
logprobs_bs1
)
}
(BS=1) "
f
"vs
{
len
(
logprobs_bsN
)
}
(BS=N)"
)
differences_found
.
append
(
{
"prompt_idx"
:
i
,
"step"
:
"all"
,
"reason"
:
reason
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
}
)
continue
# Check if tokens match first
if
tokens_bs1
!=
tokens_bsN
:
differences_found
.
append
(
{
"prompt_idx"
:
i
,
"step"
:
"sampling"
,
"reason"
:
"Different tokens sampled"
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
}
)
continue
for
t
,
(
a
,
b
)
in
enumerate
(
zip
(
logprobs_bs1
,
logprobs_bsN
)):
if
a
.
shape
!=
b
.
shape
:
differences_found
.
append
(
differences_found
.
append
(
{
{
"prompt_idx"
:
i
,
"prompt_idx"
:
i
,
"step"
:
"all"
,
"step"
:
t
,
"reason"
:
reason
,
"reason"
:
f
"Shape mismatch:
{
a
.
shape
}
vs
{
b
.
shape
}
"
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
"bsN_tokens"
:
tokens_bsN
,
}
}
)
)
continue
break
# Check if tokens match first
if
not
torch
.
equal
(
a
,
b
):
if
tokens_bs1
!=
tokens_bsN
:
max_diff
=
torch
.
abs
(
a
-
b
).
max
().
item
()
print
(
f
"
\n
[EXPECTED DIVERGENCE FOUND] Prompt
{
i
}
, "
f
"Token
{
t
}
: max_diff=
{
max_diff
:.
6
e
}
"
)
bs1_tok
=
tokens_bs1
[
t
]
if
t
<
len
(
tokens_bs1
)
else
"N/A"
bsN_tok
=
tokens_bsN
[
t
]
if
t
<
len
(
tokens_bsN
)
else
"N/A"
print
(
f
" Token IDs: bs1=
{
bs1_tok
}
, bsN=
{
bsN_tok
}
"
)
print
(
f
" BS=1 logprob:
{
a
.
tolist
()
}
"
)
print
(
f
" BS=N logprob:
{
b
.
tolist
()
}
"
)
differences_found
.
append
(
differences_found
.
append
(
{
{
"prompt_idx"
:
i
,
"prompt_idx"
:
i
,
"step"
:
"sampling"
,
"step"
:
t
,
"reason"
:
"Different tokens sampled
"
,
"reason"
:
f
"Bitwise mismatch (max_diff=
{
max_diff
:.
6
e
}
)
"
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
"bsN_tokens"
:
tokens_bsN
,
}
}
)
)
continue
break
for
t
,
(
a
,
b
)
in
enumerate
(
zip
(
logprobs_bs1
,
logprobs_bsN
)):
if
a
.
shape
!=
b
.
shape
:
differences_found
.
append
(
{
"prompt_idx"
:
i
,
"step"
:
t
,
"reason"
:
f
"Shape mismatch:
{
a
.
shape
}
vs
{
b
.
shape
}
"
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
}
)
break
if
not
torch
.
equal
(
a
,
b
):
max_diff
=
torch
.
abs
(
a
-
b
).
max
().
item
()
print
(
f
"
\n
[EXPECTED DIVERGENCE FOUND] Prompt
{
i
}
, "
f
"Token
{
t
}
: max_diff=
{
max_diff
:.
6
e
}
"
)
bs1_tok
=
tokens_bs1
[
t
]
if
t
<
len
(
tokens_bs1
)
else
"N/A"
bsN_tok
=
tokens_bsN
[
t
]
if
t
<
len
(
tokens_bsN
)
else
"N/A"
print
(
f
" Token IDs: bs1=
{
bs1_tok
}
, bsN=
{
bsN_tok
}
"
)
print
(
f
" BS=1 logprob:
{
a
.
tolist
()
}
"
)
print
(
f
" BS=N logprob:
{
b
.
tolist
()
}
"
)
differences_found
.
append
(
{
"prompt_idx"
:
i
,
"step"
:
t
,
"reason"
:
f
"Bitwise mismatch (max_diff=
{
max_diff
:.
6
e
}
)"
,
"prompt_preview"
:
prompts
[
i
][:
100
],
"bs1_tokens"
:
tokens_bs1
,
"bsN_tokens"
:
tokens_bsN
,
}
)
break
# Print summary
print
(
f
"
\n
{
'='
*
80
}
"
)
if
differences_found
:
success_msg
=
(
f
"✓ SUCCESS: Batch invariance is doing something! "
f
"Found
{
len
(
differences_found
)
}
/
{
len
(
prompts
)
}
prompts "
f
"with differences when batch invariance was DISABLED."
)
print
(
success_msg
)
print
(
f
"
{
'='
*
80
}
"
)
for
diff
in
differences_found
:
print
(
f
"
\n
Prompt
{
diff
[
'prompt_idx'
]
}
(step
{
diff
[
'step'
]
}
):"
)
print
(
f
" Reason:
{
diff
[
'reason'
]
}
"
)
print
(
f
" Preview:
{
diff
[
'prompt_preview'
]
}
..."
)
if
"bs1_tokens"
in
diff
:
print
(
f
" BS=1 tokens:
{
diff
[
'bs1_tokens'
]
}
"
)
if
"bsN_tokens"
in
diff
:
print
(
f
" BS=N tokens:
{
diff
[
'bsN_tokens'
]
}
"
)
print
(
f
"
{
'='
*
80
}
\n
"
)
# Test PASSES because we found differences (batch invariance matters!)
return
else
:
# Test FAILS because everything matched even without batch invariance
fail_msg
=
(
f
"✗ UNEXPECTED: All
{
len
(
prompts
)
}
prompts matched "
f
"between BS=1 and BS=N even with batch invariance DISABLED. "
f
"This suggests batch invariance might not be necessary, "
f
"or the test needs more sensitive prompts."
)
print
(
fail_msg
)
print
(
f
"
{
'='
*
80
}
\n
"
)
pytest
.
fail
(
fail_msg
)
finally
:
# Print summary
# Restore original value
print
(
f
"
\n
{
'='
*
80
}
"
)
if
old_value
is
None
:
if
differences_found
:
os
.
environ
.
pop
(
"VLLM_BATCH_INVARIANT"
,
None
)
success_msg
=
(
else
:
f
"✓ SUCCESS: Batch invariance is doing something! "
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
old_value
f
"Found
{
len
(
differences_found
)
}
/
{
len
(
prompts
)
}
prompts "
f
"with differences when batch invariance was DISABLED."
)
print
(
success_msg
)
print
(
f
"
{
'='
*
80
}
"
)
for
diff
in
differences_found
:
print
(
f
"
\n
Prompt
{
diff
[
'prompt_idx'
]
}
(step
{
diff
[
'step'
]
}
):"
)
print
(
f
" Reason:
{
diff
[
'reason'
]
}
"
)
print
(
f
" Preview:
{
diff
[
'prompt_preview'
]
}
..."
)
if
"bs1_tokens"
in
diff
:
print
(
f
" BS=1 tokens:
{
diff
[
'bs1_tokens'
]
}
"
)
if
"bsN_tokens"
in
diff
:
print
(
f
" BS=N tokens:
{
diff
[
'bsN_tokens'
]
}
"
)
print
(
f
"
{
'='
*
80
}
\n
"
)
# Test PASSES because we found differences (batch invariance matters!)
return
else
:
# Test FAILS because everything matched even without batch invariance
fail_msg
=
(
f
"✗ UNEXPECTED: All
{
len
(
prompts
)
}
prompts matched "
f
"between BS=1 and BS=N even with batch invariance DISABLED. "
f
"This suggests batch invariance might not be necessary, "
f
"or the test needs more sensitive prompts."
)
print
(
fail_msg
)
print
(
f
"
{
'='
*
80
}
\n
"
)
pytest
.
fail
(
fail_msg
)
@
skip_unsupported
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
])
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_decode_logprobs_match_prefill_logprobs
(
backend
):
def
test_decode_logprobs_match_prefill_logprobs
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
"""
Test that verifies decode logprobs match prefill logprobs.
Test that verifies decode logprobs match prefill logprobs.
...
@@ -724,8 +729,7 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
...
@@ -724,8 +729,7 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
This ensures that the logprobs from decode are consistent with what
This ensures that the logprobs from decode are consistent with what
we would get if we ran prefill on each prefix.
we would get if we ran prefill on each prefix.
"""
"""
backend
=
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
...
...
vllm/model_executor/layers/batch_invariant.py
View file @
6afc28a9
...
@@ -753,13 +753,13 @@ def override_envs_for_invariance():
...
@@ -753,13 +753,13 @@ def override_envs_for_invariance():
curr_attn_backend
=
envs
.
VLLM_ATTENTION_BACKEND
curr_attn_backend
=
envs
.
VLLM_ATTENTION_BACKEND
supported_backends
=
[
supported_backends
=
[
"FLASH_ATTN"
,
# best supported backend
"FLASH_ATTN"
,
# best supported backend
"FLEX_ATTENTION"
,
"FLASHINFER"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER_MLA"
,
"FLASHINFER_MLA"
,
"TRITON_MLA"
,
"TRITON_MLA"
,
# Not yet supported MLA backends
# Not yet supported MLA backends
# "FLASHMLA",
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
]
]
if
curr_attn_backend
not
in
supported_backends
:
if
curr_attn_backend
not
in
supported_backends
:
warning
=
(
warning
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment