Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
411d255e
Commit
411d255e
authored
Dec 08, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev_mtp_sampler' into v0.9.2-dev
parents
18b4e6f3
33e33aa7
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
143 additions
and
8 deletions
+143
-8
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+126
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+17
-4
No files found.
vllm/v1/worker/gpu_input_batch.py
View file @
411d255e
...
...
@@ -8,6 +8,7 @@ from typing import Optional, cast
import
numpy
as
np
import
torch
from
vllm
import
envs
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.pooling_params
import
PoolingParams
...
...
@@ -16,6 +17,10 @@ from vllm.utils import swap_dict_values
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.sample.logits_processor
import
(
BatchUpdateBuilder
,
LogitBiasLogitsProcessor
,
LogitsProcessorManager
,
MinPLogitsProcessor
,
MinTokensLogitsProcessor
,
MoveDirectionality
,
init_builtin_logitsprocs
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
...
@@ -192,6 +197,10 @@ class InputBatch:
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_reqs
:
set
[
str
]
=
set
()
# Track whether sampling metadata is currently expanded to
# per-token shape (spec decode reject path).
self
.
_sampling_metadata_is_expanded
=
False
# lora related
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,
),
dtype
=
np
.
int32
)
...
...
@@ -593,7 +602,7 @@ class InputBatch:
del
self
.
_req_ids
[
self
.
num_reqs
:]
del
self
.
req_output_token_ids
[
self
.
num_reqs
:]
def
refresh_metadata
(
self
):
def
refresh_metadata
(
self
,
repeat_counts
:
Optional
[
torch
.
Tensor
]
=
None
):
"""Apply batch updates, reset input batch at end of step
* Apply batch add/remove/permute to logits procs' states
...
...
@@ -602,8 +611,17 @@ class InputBatch:
batch_update
=
self
.
batch_update_builder
.
get_and_reset
(
self
.
num_reqs
)
for
logit_proc
in
self
.
logitsprocs
.
all
:
logit_proc
.
update_state
(
batch_update
)
if
batch_update
:
needs_rebuild
=
(
batch_update
or
repeat_counts
is
not
None
or
self
.
_sampling_metadata_is_expanded
)
if
needs_rebuild
:
if
repeat_counts
is
None
:
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
else
:
self
.
sampling_metadata
=
self
.
_make_sampling_metadata_expanded
(
repeat_counts
)
self
.
_sampling_metadata_is_expanded
=
repeat_counts
is
not
None
# Expanded metadata is built on demand; do not cache a copy here.
def
_make_sampling_metadata
(
self
)
->
SamplingMetadata
:
num_reqs
=
self
.
num_reqs
...
...
@@ -666,6 +684,105 @@ class InputBatch:
logitsprocs
=
self
.
logitsprocs
,
)
def
_make_sampling_metadata_expanded
(
self
,
repeat_counts
:
torch
.
Tensor
)
->
SamplingMetadata
:
num_reqs
=
self
.
num_reqs
repeat_counts_cpu
=
repeat_counts
all_greedy
=
self
.
all_greedy
all_random
=
self
.
all_random
# For reject-sampling optimization, force greedy sampling to keep
# rejection sampler assumptions (per-request shapes) intact.
def
_expand_cpu_to_gpu
(
t
:
Optional
[
torch
.
Tensor
],
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
Optional
[
torch
.
Tensor
]:
if
t
is
None
:
return
None
base
=
t
[:
num_reqs
]
if
repeat_counts_cpu
is
not
None
:
base
=
base
.
repeat_interleave
(
repeat_counts_cpu
,
dim
=
0
)
return
base
.
to
(
device
=
self
.
device
,
dtype
=
dtype
if
dtype
is
not
None
else
None
,
non_blocking
=
True
)
needs_prompt_token_ids
=
(
not
self
.
no_penalties
or
(
self
.
num_reqs
>
0
and
self
.
logits_processing_needs_token_ids
))
if
needs_prompt_token_ids
:
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids
=
self
.
_make_prompt_token_ids_tensor
(
repeat_counts_cpu
)
else
:
prompt_token_ids
=
None
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
=
None
if
not
self
.
no_allowed_token_ids
:
assert
self
.
allowed_token_ids_mask
is
not
None
allowed_token_ids_mask
=
self
.
allowed_token_ids_mask_cpu_tensor
# Expand per-request metadata to per-token shape when repeat_counts
# is provided (spec decode reject-sampling path).
top_p_cpu
=
None
if
self
.
no_top_p
else
self
.
top_p_cpu_tensor
top_k_cpu
=
None
if
self
.
no_top_k
else
self
.
top_k_cpu_tensor
repeat_list
=
repeat_counts_cpu
.
tolist
()
row_offsets
:
list
[
int
]
=
[]
total_rows
=
0
for
repeat
in
repeat_list
:
row_offsets
.
append
(
total_rows
)
total_rows
+=
int
(
repeat
)
expanded_output_token_ids
:
list
[
list
[
int
]]
=
[]
expanded_bad_words_token_ids
:
dict
[
int
,
list
[
list
[
int
]]]
=
{}
expanded_generators
:
dict
[
int
,
torch
.
Generator
]
=
{}
row_idx
=
0
for
req_idx
in
range
(
num_reqs
):
repeat
=
int
(
repeat_list
[
req_idx
])
if
repeat
<=
0
:
continue
output_tokens
=
self
.
req_output_token_ids
[
req_idx
]
assert
output_tokens
is
not
None
bad_words
=
self
.
bad_words_token_ids
.
get
(
req_idx
)
generator
=
self
.
generators
.
get
(
req_idx
)
for
_
in
range
(
repeat
):
expanded_output_token_ids
.
append
(
output_tokens
)
if
bad_words
is
not
None
:
expanded_bad_words_token_ids
[
row_idx
]
=
bad_words
if
generator
is
not
None
:
expanded_generators
[
row_idx
]
=
generator
row_idx
+=
1
return
SamplingMetadata
(
temperature
=
_expand_cpu_to_gpu
(
None
if
all_greedy
else
self
.
temperature_cpu_tensor
),
all_greedy
=
all_greedy
,
all_random
=
all_random
,
top_p
=
_expand_cpu_to_gpu
(
top_p_cpu
),
top_k
=
_expand_cpu_to_gpu
(
top_k_cpu
,
dtype
=
torch
.
int32
),
generators
=
expanded_generators
,
max_num_logprobs
=
self
.
max_num_logprobs
,
prompt_token_ids
=
prompt_token_ids
,
frequency_penalties
=
(
None
if
self
.
no_penalties
else
_expand_cpu_to_gpu
(
self
.
frequency_penalties_cpu_tensor
)),
presence_penalties
=
(
None
if
self
.
no_penalties
else
_expand_cpu_to_gpu
(
self
.
presence_penalties_cpu_tensor
)),
repetition_penalties
=
(
None
if
self
.
no_penalties
else
_expand_cpu_to_gpu
(
self
.
repetition_penalties_cpu_tensor
)),
output_token_ids
=
expanded_output_token_ids
,
no_penalties
=
self
.
no_penalties
,
allowed_token_ids_mask
=
_expand_cpu_to_gpu
(
allowed_token_ids_mask
,
dtype
=
torch
.
bool
),
bad_words_token_ids
=
expanded_bad_words_token_ids
,
logitsprocs
=
self
.
logitsprocs
,
)
@
property
def
pooling_metadata
(
self
)
->
PoolingMetadata
:
if
len
(
self
.
pooling_params
)
==
0
:
...
...
@@ -685,7 +802,9 @@ class InputBatch:
pooling_params
=
pooling_params
,
)
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
def
_make_prompt_token_ids_tensor
(
self
,
repeat_counts_cpu
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
max_prompt_len
=
self
.
num_prompt_tokens
[:
self
.
num_reqs
].
max
()
prompt_token_ids_cpu_tensor
=
torch
.
empty
(
(
self
.
num_reqs
,
max_prompt_len
),
...
...
@@ -700,6 +819,9 @@ class InputBatch:
# token_id of this value.
for
i
in
range
(
self
.
num_reqs
):
prompt_token_ids
[
i
,
self
.
num_prompt_tokens
[
i
]:]
=
self
.
vocab_size
if
repeat_counts_cpu
is
not
None
:
prompt_token_ids_cpu_tensor
=
prompt_token_ids_cpu_tensor
\
.
repeat_interleave
(
repeat_counts_cpu
,
dim
=
0
)
return
prompt_token_ids_cpu_tensor
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
411d255e
...
...
@@ -572,8 +572,21 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self
.
input_batch
.
condense
()
# Allow attention backend to reorder the batch, potentially
self
.
_may_reorder_batch
(
scheduler_output
)
# Refresh batch metadata with any pending updates.
self
.
input_batch
.
refresh_metadata
()
# Refresh batch metadata with any pending updates. If we are in spec
# decode + reject mode, also expand sampling metadata to token shape
# using per-request repeat counts.
repeat_counts
:
Optional
[
torch
.
Tensor
]
=
None
if
envs
.
VLLM_REJECT_SAMPLE_OPT
and
\
scheduler_output
.
scheduled_spec_decode_tokens
:
num_reqs
=
self
.
input_batch
.
num_reqs
num_draft_tokens
=
np
.
zeros
(
num_reqs
,
dtype
=
np
.
int32
)
for
req_id
,
draft_token_ids
in
(
scheduler_output
.
scheduled_spec_decode_tokens
.
items
()):
req_idx
=
self
.
input_batch
.
req_id_to_index
.
get
(
req_id
)
if
req_idx
is
not
None
:
num_draft_tokens
[
req_idx
]
=
len
(
draft_token_ids
)
repeat_counts
=
torch
.
from_numpy
(
num_draft_tokens
).
add_
(
1
)
self
.
input_batch
.
refresh_metadata
(
repeat_counts
)
def
_get_cumsum_and_arange
(
self
,
...
...
@@ -3360,8 +3373,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
)
sampler_output
.
sampled_token_ids
=
output_token_ids
else
:
sampling_metadata
.
all_greedy
=
True
sampling_metadata
.
all_random
=
False
#
sampling_metadata.all_greedy = True
#
sampling_metadata.all_random = False
sampler_output
=
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment