Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ddcbc2f3
Unverified
Commit
ddcbc2f3
authored
Oct 09, 2025
by
Nick Hill
Committed by
GitHub
Oct 09, 2025
Browse files
[Misc] Misc code simplifications (#26450)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
a83ff278
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
79 additions
and
90 deletions
+79
-90
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+1
-1
vllm/v1/core/sched/utils.py
vllm/v1/core/sched/utils.py
+1
-2
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+8
-13
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+26
-30
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+9
-11
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+34
-33
No files found.
vllm/v1/core/sched/scheduler.py
View file @
ddcbc2f3
...
@@ -1474,7 +1474,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1474,7 +1474,7 @@ class Scheduler(SchedulerInterface):
affected_req_ids
.
add
(
request
.
request_id
)
affected_req_ids
.
add
(
request
.
request_id
)
return
(
affected_req_ids
,
total_affected_tokens
)
return
affected_req_ids
,
total_affected_tokens
def
_handle_invalid_blocks
(
self
,
invalid_block_ids
:
set
[
int
])
->
set
[
str
]:
def
_handle_invalid_blocks
(
self
,
invalid_block_ids
:
set
[
int
])
->
set
[
str
]:
total_requests_to_reschedule
=
0
total_requests_to_reschedule
=
0
...
...
vllm/v1/core/sched/utils.py
View file @
ddcbc2f3
...
@@ -59,8 +59,7 @@ def check_stop(
...
@@ -59,8 +59,7 @@ def check_stop(
sampling_params
=
request
.
sampling_params
sampling_params
=
request
.
sampling_params
assert
sampling_params
is
not
None
assert
sampling_params
is
not
None
min_tokens
=
sampling_params
.
min_tokens
if
request
.
num_output_tokens
<
sampling_params
.
min_tokens
:
if
request
.
num_output_tokens
<
min_tokens
:
return
False
return
False
last_token_id
=
request
.
output_token_ids
[
-
1
]
last_token_id
=
request
.
output_token_ids
[
-
1
]
...
...
vllm/v1/sample/rejection_sampler.py
View file @
ddcbc2f3
...
@@ -147,22 +147,20 @@ class RejectionSampler(nn.Module):
...
@@ -147,22 +147,20 @@ class RejectionSampler(nn.Module):
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
metadata
:
SpecDecodeMetadata
,
metadata
:
SpecDecodeMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
has_penalties
=
not
sampling_metadata
.
no_penalties
any_penalties_or_bad_words
=
(
any_penalties_or_bad_words
=
(
sampling_metadata
.
bad_words_token_ids
or
not
sampling_metadata
.
no
_penalties
sampling_metadata
.
bad_words_token_ids
or
has
_penalties
)
)
output_token_ids
=
sampling_metadata
.
output_token_ids
output_token_ids
=
sampling_metadata
.
output_token_ids
if
any_penalties_or_bad_words
:
if
any_penalties_or_bad_words
:
output_token_ids
=
self
.
_combine_outputs_with_spec_tokens
(
output_token_ids
=
self
.
_combine_outputs_with_spec_tokens
(
sampling_metadata
.
output_token_ids
,
output_token_ids
,
sampling_metadata
.
spec_token_ids
,
sampling_metadata
.
spec_token_ids
,
)
)
# Calculate indices of target logits.
# Calculate indices of target logits.
if
(
if
sampling_metadata
.
allowed_token_ids_mask
is
not
None
or
has_penalties
:
sampling_metadata
.
allowed_token_ids_mask
is
not
None
or
not
sampling_metadata
.
no_penalties
):
num_requests
=
len
(
sampling_metadata
.
output_token_ids
)
num_requests
=
len
(
sampling_metadata
.
output_token_ids
)
num_draft_tokens
=
torch
.
tensor
(
metadata
.
num_draft_tokens
,
device
=
"cpu"
)
num_draft_tokens
=
torch
.
tensor
(
metadata
.
num_draft_tokens
,
device
=
"cpu"
)
original_indices
=
torch
.
arange
(
num_requests
,
device
=
"cpu"
)
original_indices
=
torch
.
arange
(
num_requests
,
device
=
"cpu"
)
...
@@ -180,18 +178,15 @@ class RejectionSampler(nn.Module):
...
@@ -180,18 +178,15 @@ class RejectionSampler(nn.Module):
logits
.
masked_fill_
(
token_mask
,
float
(
"-inf"
))
logits
.
masked_fill_
(
token_mask
,
float
(
"-inf"
))
# Apply bad words exclusion.
# Apply bad words exclusion.
if
sampling_metadata
.
bad_words_token_ids
:
if
bad_words_token_ids
:
=
sampling_metadata
.
bad_words_token_ids
:
apply_bad_words_with_drafts
(
apply_bad_words_with_drafts
(
logits
,
logits
,
bad_words_token_ids
,
output_token_ids
,
metadata
.
num_draft_tokens
sampling_metadata
.
bad_words_token_ids
,
output_token_ids
,
metadata
.
num_draft_tokens
,
)
)
return
logits
return
logits
@
staticmethod
def
apply_penalties
(
def
apply_penalties
(
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
metadata
:
SpecDecodeMetadata
,
metadata
:
SpecDecodeMetadata
,
...
@@ -218,8 +213,8 @@ class RejectionSampler(nn.Module):
...
@@ -218,8 +213,8 @@ class RejectionSampler(nn.Module):
)
)
return
logits
return
logits
@
staticmethod
def
_combine_outputs_with_spec_tokens
(
def
_combine_outputs_with_spec_tokens
(
self
,
output_token_ids
:
list
[
list
[
int
]],
output_token_ids
:
list
[
list
[
int
]],
spec_token_ids
:
Optional
[
list
[
list
[
int
]]]
=
None
,
spec_token_ids
:
Optional
[
list
[
list
[
int
]]]
=
None
,
)
->
list
[
list
[
int
]]:
)
->
list
[
list
[
int
]]:
...
...
vllm/v1/sample/sampler.py
View file @
ddcbc2f3
...
@@ -120,8 +120,8 @@ class Sampler(nn.Module):
...
@@ -120,8 +120,8 @@ class Sampler(nn.Module):
)
)
return
sampler_output
return
sampler_output
@
staticmethod
def
apply_temperature
(
def
apply_temperature
(
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
temp
:
torch
.
Tensor
,
temp
:
torch
.
Tensor
,
all_random
:
bool
,
all_random
:
bool
,
...
@@ -132,7 +132,8 @@ class Sampler(nn.Module):
...
@@ -132,7 +132,8 @@ class Sampler(nn.Module):
temp
=
torch
.
where
(
temp
<
_SAMPLING_EPS
,
1.0
,
temp
)
temp
=
torch
.
where
(
temp
<
_SAMPLING_EPS
,
1.0
,
temp
)
return
logits
.
div_
(
temp
.
unsqueeze
(
dim
=
1
))
return
logits
.
div_
(
temp
.
unsqueeze
(
dim
=
1
))
def
greedy_sample
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
@
staticmethod
def
greedy_sample
(
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
logits
.
argmax
(
dim
=-
1
).
view
(
-
1
)
return
logits
.
argmax
(
dim
=-
1
).
view
(
-
1
)
def
sample
(
def
sample
(
...
@@ -191,11 +192,12 @@ class Sampler(nn.Module):
...
@@ -191,11 +192,12 @@ class Sampler(nn.Module):
)
)
return
sampled
,
processed_logprobs
return
sampled
,
processed_logprobs
def
compute_logprobs
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
@
staticmethod
def
compute_logprobs
(
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
logits
.
log_softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
logits
.
log_softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
@
staticmethod
def
gather_logprobs
(
def
gather_logprobs
(
self
,
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
num_logprobs
:
int
,
num_logprobs
:
int
,
token_ids
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
...
@@ -238,8 +240,8 @@ class Sampler(nn.Module):
...
@@ -238,8 +240,8 @@ class Sampler(nn.Module):
return
LogprobsTensors
(
indices
,
logprobs
,
token_ranks
)
return
LogprobsTensors
(
indices
,
logprobs
,
token_ranks
)
@
staticmethod
def
_combine_outputs_with_spec_tokens
(
def
_combine_outputs_with_spec_tokens
(
self
,
output_token_ids
:
list
[
list
[
int
]],
output_token_ids
:
list
[
list
[
int
]],
spec_token_ids
:
Optional
[
list
[
list
[
int
]]]
=
None
,
spec_token_ids
:
Optional
[
list
[
list
[
int
]]]
=
None
,
)
->
list
[
list
[
int
]]:
)
->
list
[
list
[
int
]]:
...
@@ -257,8 +259,9 @@ class Sampler(nn.Module):
...
@@ -257,8 +259,9 @@ class Sampler(nn.Module):
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
predict_bonus_token
:
bool
,
predict_bonus_token
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
bad_words_token_ids
=
sampling_metadata
.
bad_words_token_ids
any_penalties_or_bad_words
=
(
any_penalties_or_bad_words
=
(
sampling_metadata
.
bad_words_token_ids
or
not
sampling_metadata
.
no_penalties
bool
(
bad_words_token_ids
)
or
not
sampling_metadata
.
no_penalties
)
)
output_token_ids
=
sampling_metadata
.
output_token_ids
output_token_ids
=
sampling_metadata
.
output_token_ids
...
@@ -266,7 +269,7 @@ class Sampler(nn.Module):
...
@@ -266,7 +269,7 @@ class Sampler(nn.Module):
# Combine base outputs with spec tokens when speculative decoding
# Combine base outputs with spec tokens when speculative decoding
# is enabled.
# is enabled.
output_token_ids
=
self
.
_combine_outputs_with_spec_tokens
(
output_token_ids
=
self
.
_combine_outputs_with_spec_tokens
(
sampling_metadata
.
output_token_ids
,
output_token_ids
,
sampling_metadata
.
spec_token_ids
,
sampling_metadata
.
spec_token_ids
,
)
)
...
@@ -275,14 +278,8 @@ class Sampler(nn.Module):
...
@@ -275,14 +278,8 @@ class Sampler(nn.Module):
logits
.
masked_fill_
(
sampling_metadata
.
allowed_token_ids_mask
,
float
(
"-inf"
))
logits
.
masked_fill_
(
sampling_metadata
.
allowed_token_ids_mask
,
float
(
"-inf"
))
# Apply bad words exclusion.
# Apply bad words exclusion.
if
sampling_metadata
.
bad_words_token_ids
:
if
bad_words_token_ids
:
apply_bad_words
(
apply_bad_words
(
logits
,
bad_words_token_ids
,
output_token_ids
)
logits
,
sampling_metadata
.
bad_words_token_ids
,
output_token_ids
if
output_token_ids
is
not
None
else
sampling_metadata
.
output_token_ids
,
)
# Apply logits processors which can impact greedy sampling.
# Apply logits processors which can impact greedy sampling.
for
processor
in
sampling_metadata
.
logitsprocs
.
non_argmax_invariant
:
for
processor
in
sampling_metadata
.
logitsprocs
.
non_argmax_invariant
:
...
@@ -292,22 +289,21 @@ class Sampler(nn.Module):
...
@@ -292,22 +289,21 @@ class Sampler(nn.Module):
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
,
output_token_ids
)
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
,
output_token_ids
)
return
logits
return
logits
@
staticmethod
def
apply_penalties
(
def
apply_penalties
(
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
output_token_ids
:
Optional
[
list
[
list
[
int
]]
]
=
None
,
output_token_ids
:
list
[
list
[
int
]],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
not
sampling_metadata
.
no_penalties
:
if
sampling_metadata
.
no_penalties
:
return
logits
assert
sampling_metadata
.
prompt_token_ids
is
not
None
assert
sampling_metadata
.
prompt_token_ids
is
not
None
logits
=
apply_all_penalties
(
return
apply_all_penalties
(
logits
,
logits
,
sampling_metadata
.
prompt_token_ids
,
sampling_metadata
.
prompt_token_ids
,
sampling_metadata
.
presence_penalties
,
sampling_metadata
.
presence_penalties
,
sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
repetition_penalties
,
output_token_ids
output_token_ids
,
if
output_token_ids
is
not
None
else
sampling_metadata
.
output_token_ids
,
)
)
return
logits
vllm/v1/worker/gpu_input_batch.py
View file @
ddcbc2f3
...
@@ -62,9 +62,8 @@ class CachedRequestState:
...
@@ -62,9 +62,8 @@ class CachedRequestState:
"provided via prompt_embeds, and its ID is unknown."
"provided via prompt_embeds, and its ID is unknown."
)
)
return
self
.
prompt_token_ids
[
idx
]
return
self
.
prompt_token_ids
[
idx
]
el
if
idx
-
self
.
num_prompt_tokens
<
len
(
self
.
output_token_ids
):
if
idx
-
self
.
num_prompt_tokens
<
len
(
self
.
output_token_ids
):
return
self
.
output_token_ids
[
idx
-
self
.
num_prompt_tokens
]
return
self
.
output_token_ids
[
idx
-
self
.
num_prompt_tokens
]
else
:
return
-
1
return
-
1
...
@@ -770,14 +769,13 @@ class InputBatch:
...
@@ -770,14 +769,13 @@ class InputBatch:
not
self
.
no_penalties
not
self
.
no_penalties
or
self
.
logits_processing_needs_token_ids
[:
num_reqs
].
any
()
or
self
.
logits_processing_needs_token_ids
[:
num_reqs
].
any
()
)
)
if
needs_prompt_token_ids
:
# The prompt tokens are used only for applying penalties or
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
# need penalties/step_pooler to be applied.
prompt_token_ids
=
self
.
_make_prompt_token_ids_tensor
()
prompt_token_ids
=
(
else
:
self
.
_make_prompt_token_ids_tensor
()
if
needs_prompt_token_ids
else
None
prompt_token_ids
=
None
)
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
=
None
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
=
None
if
not
self
.
no_allowed_token_ids
:
if
not
self
.
no_allowed_token_ids
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
ddcbc2f3
...
@@ -1996,7 +1996,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1996,7 +1996,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Should be called after attention metadata creation. This just pads
# Should be called after attention metadata creation. This just pads
# the second ubatch slice out to the total number of tokens
# the second ubatch slice out to the total number of tokens
# (num_tokens + padding)
# (num_tokens + padding)
def
pad_out_ubatch_slice
(
self
,
ubatch_slices
:
UBatchSlices
,
num_total_tokens
:
int
):
@
staticmethod
def
pad_out_ubatch_slice
(
ubatch_slices
:
UBatchSlices
,
num_total_tokens
:
int
):
padded_second_ubatch_slice
=
slice
(
padded_second_ubatch_slice
=
slice
(
ubatch_slices
[
1
].
token_slice
.
start
,
num_total_tokens
ubatch_slices
[
1
].
token_slice
.
start
,
num_total_tokens
)
)
...
@@ -2085,12 +2086,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2085,12 +2086,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dict
[
str
,
Any
],
dict
[
str
,
Any
],
]:
]:
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
is_first_rank
=
get_pp_group
().
is_first_rank
# _prepare_inputs may reorder the batch, so we must gather multi
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
# modal outputs after that to ensure the correct order
if
(
if
(
self
.
supports_mm_inputs
self
.
supports_mm_inputs
and
get_pp_group
().
is_first_rank
and
is_first_rank
and
not
self
.
model_config
.
is_encoder_decoder
and
not
self
.
model_config
.
is_encoder_decoder
):
):
# Run the multimodal encoder if any.
# Run the multimodal encoder if any.
...
@@ -2115,7 +2117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2115,7 +2117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
**
self
.
_init_model_kwargs
(
num_scheduled_tokens
),
**
self
.
_init_model_kwargs
(
num_scheduled_tokens
),
**
self
.
_extract_mm_kwargs
(
scheduler_output
),
**
self
.
_extract_mm_kwargs
(
scheduler_output
),
}
}
elif
self
.
enable_prompt_embeds
and
get_pp_group
().
is_first_rank
:
elif
self
.
enable_prompt_embeds
and
is_first_rank
:
# Get the input embeddings for the tokens that are not input embeds,
# Get the input embeddings for the tokens that are not input embeds,
# then put them into the appropriate positions.
# then put them into the appropriate positions.
# TODO(qthequartermasterman): Since even when prompt embeds are
# TODO(qthequartermasterman): Since even when prompt embeds are
...
@@ -2155,7 +2157,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2155,7 +2157,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else
:
else
:
positions
=
self
.
positions
.
gpu
[:
num_input_tokens
]
positions
=
self
.
positions
.
gpu
[:
num_input_tokens
]
if
get_pp_group
().
is_first_rank
:
if
is_first_rank
:
intermediate_tensors
=
None
intermediate_tensors
=
None
else
:
else
:
intermediate_tensors
=
self
.
sync_and_slice_intermediate_tensors
(
intermediate_tensors
=
self
.
sync_and_slice_intermediate_tensors
(
...
@@ -2186,11 +2188,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2186,11 +2188,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Sample the next token and get logprobs if needed.
# Sample the next token and get logprobs if needed.
sampling_metadata
=
self
.
input_batch
.
sampling_metadata
sampling_metadata
=
self
.
input_batch
.
sampling_metadata
if
spec_decode_metadata
is
None
:
if
spec_decode_metadata
is
None
:
sampler_output
=
self
.
sampler
(
return
self
.
sampler
(
logits
=
logits
,
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
else
:
# When indexing with a tensor (bonus_logits_indices), PyTorch
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# logits tensor. This means any in-place operations on bonus_logits
...
@@ -2217,7 +2219,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2217,7 +2219,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
sampler_output
.
sampled_token_ids
=
output_token_ids
sampler_output
.
sampled_token_ids
=
output_token_ids
self
.
_update_states_after_model_execute
(
output_token_ids
)
self
.
_update_states_after_model_execute
(
output_token_ids
)
return
sampler_output
return
sampler_output
def
_bookkeeping_sync
(
def
_bookkeeping_sync
(
...
@@ -3741,7 +3742,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -3741,7 +3742,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
decode_cudagraph_batch_sizes
=
[
decode_cudagraph_batch_sizes
=
[
x
x
for
x
in
self
.
cudagraph_batch_sizes
for
x
in
self
.
cudagraph_batch_sizes
if
x
<=
max_num_tokens
and
x
>=
self
.
uniform_decode_query_len
if
max_num_tokens
>=
x
>=
self
.
uniform_decode_query_len
]
]
compilation_cases_decode
=
list
(
reversed
(
decode_cudagraph_batch_sizes
))
compilation_cases_decode
=
list
(
reversed
(
decode_cudagraph_batch_sizes
))
self
.
_capture_cudagraphs
(
self
.
_capture_cudagraphs
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment