Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1838cd48
Unverified
Commit
1838cd48
authored
Oct 04, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 04, 2025
Browse files
Revert "Add batch invariant kernel override for FlashInfer backend [2/n]" (#26220)
parent
7d6b0338
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
84 deletions
+29
-84
tests/v1/generation/test_batch_invariance.py
tests/v1/generation/test_batch_invariance.py
+23
-40
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+1
-12
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+5
-32
No files found.
tests/v1/generation/test_batch_invariance.py
View file @
1838cd48
...
...
@@ -76,21 +76,18 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
seed.
- Keep max_tokens and max_model_len bounded for speed and memory use.
"""
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
12345
)
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
num_trials
=
int
(
os
.
getenv
(
"VLLM_NEEDLE_TRIALS"
,
"5"
))
max_batch_size
=
int
(
os
.
getenv
(
"VLLM_NEEDLE_BATCH_SIZE"
,
"128"
))
min_random_prompt
=
int
(
os
.
getenv
(
"VLLM_MIN_PROMPT"
,
"1024"
))
max_random_prompt
=
int
(
os
.
getenv
(
"VLLM_MAX_PROMPT"
,
"2048"
))
assert
max_batch_size
>=
2
,
"Batch size should be >= 2 to mix needle."
batch_size
=
int
(
os
.
getenv
(
"VLLM_NEEDLE_BATCH_SIZE"
,
"64"
))
assert
batch_size
>=
2
,
"Batch size should be >= 2 to mix needle."
# Keep GPU memory usage low to avoid startup allocation failures.
gpu_mem_util
=
float
(
os
.
getenv
(
"VLLM_GPU_MEMORY_UTILIZATION"
,
"0.
4
"
))
max_model_len
=
int
(
os
.
getenv
(
"VLLM_MAX_MODEL_LEN"
,
"
5120
"
))
gpu_mem_util
=
float
(
os
.
getenv
(
"VLLM_GPU_MEMORY_UTILIZATION"
,
"0.
3
"
))
max_model_len
=
int
(
os
.
getenv
(
"VLLM_MAX_MODEL_LEN"
,
"
4096
"
))
swap_space_gb
=
int
(
os
.
getenv
(
"VLLM_SWAP_SPACE_GB"
,
"4"
))
# Sampling parameters: longer outputs with a more random-sounding
...
...
@@ -114,7 +111,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
# Engine with bs=1 behavior
llm_bs1
=
LLM_with_max_seqs
(
model
=
model
,
max_num_seqs
=
1
28
,
max_num_seqs
=
1
,
gpu_memory_utilization
=
gpu_mem_util
,
max_model_len
=
max_model_len
,
swap_space
=
swap_space_gb
,
...
...
@@ -129,7 +126,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
# Engine with larger batch limit (e.g., 64)
llm_bsN
=
LLM_with_max_seqs
(
model
=
model
,
max_num_seqs
=
128
,
max_num_seqs
=
batch_size
,
gpu_memory_utilization
=
gpu_mem_util
,
max_model_len
=
max_model_len
,
swap_space
=
swap_space_gb
,
...
...
@@ -138,17 +135,15 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
mismatches
=
0
for
trial
in
range
(
num_trials
):
# Create a batch of size `max_batch_size` and insert the needle at
# Create a batch of size `batch_size` and insert the needle at
# a random index
prompts
:
list
[
str
]
=
[]
batch_size
=
random
.
randint
(
max_batch_size
//
2
,
max_batch_size
)
needle_pos
=
random
.
randint
(
0
,
batch_size
-
1
)
for
i
in
range
(
batch_size
):
if
i
==
needle_pos
:
prompts
.
append
(
needle_prompt
)
else
:
prompts
.
append
(
_random_prompt
(
min_random_prompt
,
max_random_prompt
))
prompts
.
append
(
_random_prompt
())
# Generate with the larger-batch engine
outputs
=
llm_bsN
.
generate
(
prompts
,
sampling
)
...
...
@@ -159,19 +154,17 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
text
=
needle_output
.
outputs
[
0
].
text
if
text
!=
baseline_text
:
print
(
f
"
{
text
}
\n\n
== Not the same as ==
\n\n
{
baseline_text
}
\n\n
"
)
mismatches
+=
1
passes
=
num_trials
-
mismatches
# Dump how many passed vs failed
print
(
f
"[determinism] total=
{
num_trials
}
, passed=
{
passes
}
, "
f
"failed=
{
mismatches
}
,
max_
batch_size=
{
max_
batch_size
}
"
)
f
"failed=
{
mismatches
}
, batch_size=
{
batch_size
}
"
)
if
mismatches
>
0
:
pytest
.
fail
(
f
"Nondeterministic outputs detected:
{
mismatches
}
failed out "
f
"of
{
num_trials
}
trials (
max_
batch_size=
{
max_
batch_size
}
)."
)
f
"of
{
num_trials
}
trials (batch_size=
{
batch_size
}
)."
)
finally
:
# Ensure engines are shutdown to free GPU/VRAM across test sessions
...
...
@@ -203,14 +196,9 @@ def _extract_step_logprobs(request_output):
not
torch
.
cuda
.
is_available
(),
reason
=
"Requires CUDA to match production inference path."
,
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLEX_ATTENTION"
,
"FLASHINFER"
])
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
):
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bs2
():
backend
=
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
#model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
model_name
=
os
.
getenv
(
"VLLM_TEST_MODEL"
,
"Qwen/Qwen3-1.7B"
)
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
...
...
@@ -224,15 +212,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
prompts
=
[
"The capital of France is"
,
"The capital of Germany is"
,
_random_prompt
(
10
,
1024
),
_random_prompt
(
10
,
1024
),
_random_prompt
(
10
,
1024
),
_random_prompt
(
10
,
1024
),
_random_prompt
(
10
,
1024
),
]
sp
=
SamplingParams
(
temperature
=
0.
6
,
temperature
=
0.
0
,
top_p
=
1.0
,
max_tokens
=
8
,
# Seed shouldn't matter at temperature=0, but keeping it stable anyway.
...
...
@@ -251,25 +234,25 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
"enable logprobs return to run this test."
)
bs1_logprobs_per_prompt
.
append
(
step_logprobs
)
# BS=N: run prompts in a batch and collect logprobs per step for each
# BS=2: run prompts in a batch and collect logprobs per step for each
# prompt.
outs_batched
=
llm
.
generate
(
prompts
,
sp
,
use_tqdm
=
False
)
assert
len
(
outs_batched
)
==
len
(
prompts
)
bs
N
_logprobs_per_prompt
=
[]
bs
2
_logprobs_per_prompt
=
[]
for
o
in
outs_batched
:
step_logprobs
=
_extract_step_logprobs
(
o
)
if
step_logprobs
is
None
:
pytest
.
skip
(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
bs
N
_logprobs_per_prompt
.
append
(
step_logprobs
)
bs
2
_logprobs_per_prompt
.
append
(
step_logprobs
)
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
for
i
,
(
logprobs_bs1
,
logprobs_bs
N
)
in
enumerate
(
zip
(
bs1_logprobs_per_prompt
,
bs
N
_logprobs_per_prompt
)):
assert
len
(
logprobs_bs1
)
==
len
(
logprobs_bs
N
),
(
# Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
for
i
,
(
logprobs_bs1
,
logprobs_bs
2
)
in
enumerate
(
zip
(
bs1_logprobs_per_prompt
,
bs
2
_logprobs_per_prompt
)):
assert
len
(
logprobs_bs1
)
==
len
(
logprobs_bs
2
),
(
f
"Different number of generation steps for prompt index
{
i
}
: "
f
"
{
len
(
logprobs_bs1
)
}
(BS=1) vs
{
len
(
logprobs_bs
N
)
}
(BS=
N
)"
)
for
t
,
(
a
,
b
)
in
enumerate
(
zip
(
logprobs_bs1
,
logprobs_bs
N
)):
f
"
{
len
(
logprobs_bs1
)
}
(BS=1) vs
{
len
(
logprobs_bs
2
)
}
(BS=
2
)"
)
for
t
,
(
a
,
b
)
in
enumerate
(
zip
(
logprobs_bs1
,
logprobs_bs
2
)):
assert
a
.
shape
==
b
.
shape
,
(
f
"Logits shape mismatch at prompt
{
i
}
, step
{
t
}
: "
f
"
{
a
.
shape
}
vs
{
b
.
shape
}
"
)
...
...
vllm/model_executor/layers/batch_invariant.py
View file @
1838cd48
...
...
@@ -8,12 +8,8 @@ from typing import Any, Union
import
torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.triton_utils
import
tl
,
triton
logger
=
init_logger
(
__name__
)
def
_matmul_launch_metadata
(
grid
:
Callable
[...,
Any
],
kernel
:
Any
,
args
:
dict
[
str
,
Any
])
->
dict
[
str
,
Any
]:
...
...
@@ -561,12 +557,5 @@ def vllm_kernel_override_batch_invariant():
def
init_batch_invariance
():
# this will hit all the csrc overrides as well
if
vllm_kernel_override_batch_invariant
():
curr_attn_backend
=
envs
.
VLLM_ATTENTION_BACKEND
supported_backends
=
[
"FLEX_ATTENTION"
,
"FLASHINFER"
]
if
curr_attn_backend
not
in
supported_backends
:
warning
=
"Forcibly updating attention backend to"
\
f
"
{
supported_backends
[
0
]
}
for batch_invariant. "
\
f
" Supported backends:
{
supported_backends
}
."
logger
.
warning_once
(
warning
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
supported_backends
[
0
]
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
"FLEX_ATTENTION"
enable_batch_invariant_mode
()
vllm/v1/attention/backends/flashinfer.py
View file @
1838cd48
...
...
@@ -20,8 +20,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType
)
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_kernel_override_batch_invariant
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8StaticTensorSym
,
kNvfp4Quant
)
from
vllm.platforms
import
current_platform
...
...
@@ -44,7 +42,6 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
from
vllm.v1.kv_cache_interface
import
AttentionSpec
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
=
2048
*
1024
*
1024
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
...
...
@@ -266,15 +263,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
_prefill_wrapper
=
None
# Wrapper for prefill/append
self
.
_decode_wrapper
=
None
# Wrapper for decode (general shape)
if
vllm_kernel_override_batch_invariant
():
self
.
decode_fixed_split_size
=
2048
self
.
prefill_fixed_split_size
=
4096
self
.
disable_split_kv
=
True
else
:
self
.
decode_fixed_split_size
=
-
1
self
.
prefill_fixed_split_size
=
-
1
self
.
disable_split_kv
=
False
self
.
compilation_config
=
vllm_config
.
compilation_config
max_num_pages_per_req
=
cdiv
(
self
.
model_config
.
max_model_len
,
self
.
kv_cache_spec
.
block_size
)
...
...
@@ -368,10 +356,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def
_get_workspace_buffer
(
self
):
if
self
.
_workspace_buffer
is
None
:
buffer_size
=
FLASHINFER_WORKSPACE_BUFFER_SIZE
if
vllm_kernel_override_batch_invariant
():
buffer_size
=
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
self
.
_workspace_buffer
=
torch
.
zeros
(
buffer_size
,
self
.
_workspace_buffer
=
torch
.
zeros
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
return
self
.
_workspace_buffer
...
...
@@ -629,8 +615,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
fixed_split_size
=
self
.
prefill_fixed_split_size
,
disable_split_kv
=
self
.
disable_split_kv
,
)
else
:
attn_metadata
.
qo_indptr_gpu
=
qo_indptr_cpu
.
to
(
...
...
@@ -684,8 +668,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
fixed_split_size
=
self
.
decode_fixed_split_size
,
disable_split_kv
=
self
.
disable_split_kv
,
)
return
attn_metadata
...
...
@@ -1066,8 +1048,6 @@ def fast_plan_decode(
rope_scale
:
Optional
[
float
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
,
non_blocking
:
bool
=
True
,
fixed_split_size
:
int
=
-
1
,
disable_split_kv
:
bool
=
False
,
)
->
None
:
"""
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
...
...
@@ -1105,10 +1085,6 @@ def fast_plan_decode(
rope_scale
,
rope_theta
,
non_blocking
,
None
,
# block_tables
None
,
# seq_lens
fixed_split_size
,
disable_split_kv
,
)
self
.
vllm_first_call
=
False
return
...
...
@@ -1154,7 +1130,7 @@ def fast_plan_decode(
qo_indptr_host
=
_get_range_buf
(
batch_size
+
1
,
"cpu"
)
try
:
# Make sure we pass exactly 18 arguments for tensor core version
# Make sure we pass exactly 15 arguments for tensor core version
self
.
_plan_info
=
self
.
_cached_module
.
plan
(
self
.
_float_workspace_buffer
,
self
.
_int_workspace_buffer
,
...
...
@@ -1171,9 +1147,6 @@ def fast_plan_decode(
head_dim
,
head_dim
,
False
,
# causal
window_left
,
fixed_split_size
,
disable_split_kv
,
)
except
Exception
as
e
:
raise
RuntimeError
(
f
"Error in tensor core plan:
{
e
}
"
)
from
e
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment