Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c80c53a3
Unverified
Commit
c80c53a3
authored
Aug 22, 2025
by
Nick Hill
Committed by
GitHub
Aug 23, 2025
Browse files
[BugFix] Fix batch updates for pooling models (#23398)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
24d0c9e6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
95 additions
and
79 deletions
+95
-79
vllm/v1/sample/logits_processor/state.py
vllm/v1/sample/logits_processor/state.py
+16
-4
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+76
-70
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-5
No files found.
vllm/v1/sample/logits_processor/state.py
View file @
c80c53a3
...
...
@@ -50,6 +50,10 @@ class BatchUpdateBuilder:
self
.
added
=
added
or
[]
self
.
_is_removed_sorted
=
False
# Used to track changes in the pooling case
# where we don't populate the added list.
self
.
batch_changed
=
False
def
_ensure_removed_sorted
(
self
)
->
None
:
"""Sort removed request indices in
descending order.
...
...
@@ -80,6 +84,7 @@ class BatchUpdateBuilder:
raise
RuntimeError
(
"Cannot register new removed request after"
" self.removed has been read."
)
self
.
_removed
.
append
(
index
)
self
.
batch_changed
=
True
def
has_removed
(
self
)
->
bool
:
return
bool
(
self
.
_removed
)
...
...
@@ -98,9 +103,15 @@ class BatchUpdateBuilder:
return
self
.
_removed
.
pop
()
return
None
def
_is_update
(
self
)
->
bool
:
"""True if there is a batch state change"""
return
any
((
self
.
_removed
,
self
.
moved
,
self
.
added
))
def
reset
(
self
)
->
bool
:
"""Returns True if there were any changes to the batch."""
self
.
_is_removed_sorted
=
False
self
.
_removed
.
clear
()
self
.
moved
.
clear
()
self
.
added
.
clear
()
batch_changed
=
self
.
batch_changed
self
.
batch_changed
=
False
return
batch_changed
def
get_and_reset
(
self
,
batch_size
:
int
)
->
Optional
[
BatchUpdate
]:
"""Generate a logitsprocs batch update data structure and reset
...
...
@@ -114,7 +125,8 @@ class BatchUpdateBuilder:
"""
# Reset removal-sorting logic
self
.
_is_removed_sorted
=
False
if
not
self
.
_is_update
():
self
.
batch_changed
=
False
if
not
any
((
self
.
_removed
,
self
.
moved
,
self
.
added
)):
# No update; short-circuit
return
None
# Build batch state update
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
c80c53a3
...
...
@@ -65,8 +65,7 @@ class CachedRequestState:
def
get_token_id
(
self
,
idx
:
int
)
->
int
:
if
idx
<
self
.
num_prompt_tokens
:
return
self
.
prompt_token_ids
[
idx
]
else
:
return
self
.
output_token_ids
[
idx
-
self
.
num_prompt_tokens
]
return
self
.
output_token_ids
[
idx
-
self
.
num_prompt_tokens
]
class
InputBatch
:
...
...
@@ -261,30 +260,27 @@ class InputBatch:
Not applicable to pooling models.
"""
# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs
assert
request
.
sampling_params
# Fill the next empty index if there is one.
if
(
new_req_index
:
=
self
.
batch_update_builder
.
pop_removed
())
is
None
:
# Append to end otherwise.
new_req_index
=
self
.
num_reqs
assert
new_req_index
<
self
.
max_num_reqs
self
.
batch_update_builder
.
added
.
append
(
(
new_req_index
,
request
.
sampling_params
,
request
.
prompt_token_ids
,
request
.
output_token_ids
))
self
.
batch_update_builder
.
batch_changed
=
True
if
request
.
sampling_params
:
# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs.
self
.
batch_update_builder
.
added
.
append
(
(
new_req_index
,
request
.
sampling_params
,
request
.
prompt_token_ids
,
request
.
output_token_ids
))
return
new_req_index
def
add_request
(
self
,
request
:
"CachedRequestState"
,
)
->
int
:
if
not
self
.
is_pooling_model
:
# New request index bookkeeping for autoregressive models.
req_index
=
self
.
_register_add_request
(
request
)
else
:
req_index
=
self
.
num_reqs
req_index
=
self
.
_register_add_request
(
request
)
req_id
=
request
.
req_id
if
req_index
==
len
(
self
.
_req_ids
):
...
...
@@ -389,7 +385,7 @@ class InputBatch:
self
.
logits_processing_needs_token_ids
[
req_index
]
=
(
pooling_params
.
requires_token_ids
)
else
:
raise
NotImplementedError
(
request
)
raise
NotImplementedError
(
"Unrecognized request type"
)
# Add request lora ID
if
request
.
lora_request
:
...
...
@@ -419,13 +415,25 @@ class InputBatch:
req_index
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
if
req_index
is
None
:
return
None
if
not
self
.
is_pooling_model
:
# Autoregressive models require bookkeeping of removed requests to
# support logitsprocs.
self
.
batch_update_builder
.
removed_append
(
req_index
)
self
.
batch_update_builder
.
removed_append
(
req_index
)
self
.
_req_ids
[
req_index
]
=
None
self
.
req_output_token_ids
[
req_index
]
=
None
# LoRA
lora_id
=
self
.
request_lora_mapping
[
req_index
]
if
lora_id
!=
0
:
lora_req_ids
=
self
.
lora_id_to_request_ids
[
lora_id
]
lora_req_ids
.
discard
(
req_id
)
if
not
lora_req_ids
:
del
self
.
lora_id_to_request_ids
[
lora_id
]
del
self
.
lora_id_to_lora_request
[
lora_id
]
self
.
request_lora_mapping
[
req_index
]
=
0
if
self
.
is_pooling_model
:
self
.
pooling_params
.
pop
(
req_id
,
None
)
return
req_index
self
.
greedy_reqs
.
discard
(
req_id
)
self
.
random_reqs
.
discard
(
req_id
)
self
.
top_p_reqs
.
discard
(
req_id
)
...
...
@@ -439,29 +447,14 @@ class InputBatch:
self
.
num_prompt_logprobs
.
pop
(
req_id
,
None
)
self
.
in_progress_prompt_logprobs_cpu
.
pop
(
req_id
,
None
)
# LoRA
lora_id
=
self
.
request_lora_mapping
[
req_index
]
if
lora_id
!=
0
:
lora_req_ids
=
self
.
lora_id_to_request_ids
[
lora_id
]
lora_req_ids
.
discard
(
req_id
)
if
not
lora_req_ids
:
del
self
.
lora_id_to_request_ids
[
lora_id
]
del
self
.
lora_id_to_lora_request
[
lora_id
]
self
.
request_lora_mapping
[
req_index
]
=
0
self
.
has_allowed_token_ids
.
discard
(
req_id
)
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
# False means we don't fill with -inf.
self
.
allowed_token_ids_mask_cpu_tensor
[
req_index
].
fill_
(
False
)
self
.
bad_words_token_ids
.
pop
(
req_index
,
None
)
self
.
pooling_params
.
pop
(
req_id
,
None
)
return
req_index
def
swap_states
(
self
,
i1
:
int
,
i2
:
int
)
->
None
:
# For autoregressive models, track detailed request reordering info
# to support logitsprocs
self
.
batch_update_builder
.
moved
.
append
(
(
i1
,
i2
,
MoveDirectionality
.
SWAP
))
old_id_i1
=
self
.
_req_ids
[
i1
]
old_id_i2
=
self
.
_req_ids
[
i2
]
self
.
_req_ids
[
i1
],
self
.
_req_ids
[
i2
]
=
\
...
...
@@ -479,18 +472,6 @@ class InputBatch:
self
.
num_prompt_tokens
[
i2
],
self
.
num_prompt_tokens
[
i1
]
self
.
num_computed_tokens_cpu
[
i1
],
self
.
num_computed_tokens_cpu
[
i2
]
=
\
self
.
num_computed_tokens_cpu
[
i2
],
self
.
num_computed_tokens_cpu
[
i1
]
self
.
temperature_cpu
[
i1
],
self
.
temperature_cpu
[
i2
]
=
\
self
.
temperature_cpu
[
i2
],
self
.
temperature_cpu
[
i1
]
self
.
top_p_cpu
[
i1
],
self
.
top_p_cpu
[
i2
]
=
\
self
.
top_p_cpu
[
i2
],
self
.
top_p_cpu
[
i1
]
self
.
top_k_cpu
[
i1
],
self
.
top_k_cpu
[
i2
]
=
\
self
.
top_k_cpu
[
i2
],
self
.
top_k_cpu
[
i1
]
self
.
frequency_penalties_cpu
[
i1
],
self
.
frequency_penalties_cpu
[
i2
]
=
\
self
.
frequency_penalties_cpu
[
i2
],
self
.
frequency_penalties_cpu
[
i1
]
self
.
presence_penalties_cpu
[
i1
],
self
.
presence_penalties_cpu
[
i2
]
=
\
self
.
presence_penalties_cpu
[
i2
],
self
.
presence_penalties_cpu
[
i1
]
self
.
repetition_penalties_cpu
[
i1
],
self
.
repetition_penalties_cpu
[
i2
]
=
\
self
.
repetition_penalties_cpu
[
i2
],
self
.
repetition_penalties_cpu
[
i1
]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
...
...
@@ -501,18 +482,41 @@ class InputBatch:
self
.
token_ids_cpu
[
i1
,
...]
=
self
.
token_ids_cpu
[
i2
,
...]
self
.
token_ids_cpu
[
i2
,
...]
=
tmp
swap_dict_values
(
self
.
generators
,
i1
,
i2
)
swap_dict_values
(
self
.
bad_words_token_ids
,
i1
,
i2
)
self
.
block_table
.
swap_row
(
i1
,
i2
)
self
.
request_lora_mapping
[
i1
],
self
.
request_lora_mapping
[
i2
]
=
\
self
.
request_lora_mapping
[
i1
],
self
.
request_lora_mapping
[
i2
]
=
\
self
.
request_lora_mapping
[
i2
],
self
.
request_lora_mapping
[
i1
]
if
self
.
is_pooling_model
:
# Sampling and logits parameters don't apply to pooling models.
return
# For autoregressive models, track detailed request reordering info
# to support logitsprocs.
self
.
batch_update_builder
.
moved
.
append
(
(
i1
,
i2
,
MoveDirectionality
.
SWAP
))
self
.
temperature_cpu
[
i1
],
self
.
temperature_cpu
[
i2
]
=
\
self
.
temperature_cpu
[
i2
],
self
.
temperature_cpu
[
i1
]
self
.
top_p_cpu
[
i1
],
self
.
top_p_cpu
[
i2
]
=
\
self
.
top_p_cpu
[
i2
],
self
.
top_p_cpu
[
i1
]
self
.
top_k_cpu
[
i1
],
self
.
top_k_cpu
[
i2
]
=
\
self
.
top_k_cpu
[
i2
],
self
.
top_k_cpu
[
i1
]
self
.
frequency_penalties_cpu
[
i1
],
self
.
frequency_penalties_cpu
[
i2
]
=
\
self
.
frequency_penalties_cpu
[
i2
],
self
.
frequency_penalties_cpu
[
i1
]
self
.
presence_penalties_cpu
[
i1
],
self
.
presence_penalties_cpu
[
i2
]
=
\
self
.
presence_penalties_cpu
[
i2
],
self
.
presence_penalties_cpu
[
i1
]
self
.
repetition_penalties_cpu
[
i1
],
self
.
repetition_penalties_cpu
[
i2
]
=
\
self
.
repetition_penalties_cpu
[
i2
],
self
.
repetition_penalties_cpu
[
i1
]
swap_dict_values
(
self
.
generators
,
i1
,
i2
)
swap_dict_values
(
self
.
bad_words_token_ids
,
i1
,
i2
)
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
self
.
allowed_token_ids_mask_cpu_tensor
[
i1
],
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i2
]
=
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i2
],
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i1
]
self
.
block_table
.
swap_row
(
i1
,
i2
)
def
condense
(
self
)
->
None
:
"""Slide non-empty requests down into lower, empty indices.
...
...
@@ -529,12 +533,6 @@ class InputBatch:
"""
num_reqs
=
self
.
num_reqs
if
self
.
is_pooling_model
:
# Will be contiguous in pooling case, just trim the lists.
del
self
.
_req_ids
[
num_reqs
:]
del
self
.
req_output_token_ids
[
num_reqs
:]
return
if
not
(
empty_req_indices
:
=
self
.
batch_update_builder
.
removed
):
# All removed requests were replaced by added requests, or else no
# requests were removed at all. No condense() needed
...
...
@@ -562,11 +560,6 @@ class InputBatch:
# Move active request down into empty request
# index.
self
.
batch_update_builder
.
pop_removed
()
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self
.
batch_update_builder
.
moved
.
append
(
(
last_req_index
,
empty_index
,
MoveDirectionality
.
UNIDIRECTIONAL
))
req_id
=
self
.
_req_ids
[
last_req_index
]
output_token_ids
=
self
.
req_output_token_ids
[
last_req_index
]
assert
req_id
is
not
None
...
...
@@ -587,6 +580,21 @@ class InputBatch:
self
.
num_computed_tokens_cpu
[
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
self
.
block_table
.
move_row
(
last_req_index
,
empty_index
)
self
.
request_lora_mapping
[
empty_index
]
=
self
.
request_lora_mapping
[
last_req_index
]
if
self
.
is_pooling_model
:
last_req_index
-=
1
# Sampling state not used by pooling models.
continue
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self
.
batch_update_builder
.
moved
.
append
(
(
last_req_index
,
empty_index
,
MoveDirectionality
.
UNIDIRECTIONAL
))
self
.
temperature_cpu
[
empty_index
]
=
self
.
temperature_cpu
[
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
...
...
@@ -601,9 +609,6 @@ class InputBatch:
if
generator
is
not
None
:
self
.
generators
[
empty_index
]
=
generator
self
.
request_lora_mapping
[
empty_index
]
=
self
.
request_lora_mapping
[
last_req_index
]
# TODO convert these to LogitsProcessors
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
self
.
allowed_token_ids_mask_cpu_tensor
[
...
...
@@ -626,8 +631,9 @@ class InputBatch:
"""Apply any batch updates to sampling metadata."""
if
self
.
is_pooling_model
:
# Batch changes every step for pooling models.
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
batch_changed
=
self
.
batch_update_builder
.
reset
()
if
batch_changed
:
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
return
# For non-pooling models - generate and apply logitsprocs update;
...
...
@@ -720,7 +726,8 @@ class InputBatch:
)
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
max_prompt_len
=
self
.
num_prompt_tokens
[:
self
.
num_reqs
].
max
()
num_reqs
=
self
.
num_reqs
max_prompt_len
=
self
.
num_prompt_tokens
[:
num_reqs
].
max
()
prompt_token_ids_cpu_tensor
=
torch
.
empty
(
(
self
.
num_reqs
,
max_prompt_len
),
device
=
"cpu"
,
...
...
@@ -728,11 +735,10 @@ class InputBatch:
pin_memory
=
self
.
pin_memory
,
)
prompt_token_ids
=
prompt_token_ids_cpu_tensor
.
numpy
()
prompt_token_ids
[:]
=
self
.
token_ids_cpu
[:
self
.
num_reqs
,
:
max_prompt_len
]
prompt_token_ids
[:]
=
self
.
token_ids_cpu
[:
num_reqs
,
:
max_prompt_len
]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for
i
in
range
(
self
.
num_reqs
):
for
i
in
range
(
num_reqs
):
prompt_token_ids
[
i
,
self
.
num_prompt_tokens
[
i
]:]
=
self
.
vocab_size
return
prompt_token_ids_cpu_tensor
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
c80c53a3
...
...
@@ -1489,10 +1489,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for
raw_output
,
seq_len
,
prompt_len
in
zip
(
raw_pooler_output
,
seq_lens_cpu
,
pooling_metadata
.
prompt_lens
):
if
seq_len
==
prompt_len
:
pooler_output
.
append
(
raw_output
.
data
)
else
:
pooler_output
.
append
(
None
)
output
=
raw_output
.
data
if
seq_len
==
prompt_len
else
None
pooler_output
.
append
(
output
)
return
ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
,
...
...
@@ -1522,7 +1520,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Prepare the decoder inputs.
(
attn_metadata
,
logits_indices
,
spec_decode_metadata
,
num_scheduled_tokens_np
,
spec_decode_common_attn_metadata
,
max_query_len
)
=
(
self
.
_prepare_inputs
(
scheduler_output
)
)
max_query_len
)
=
self
.
_prepare_inputs
(
scheduler_output
)
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
(
self
.
compilation_config
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment