Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
411d255e
Commit
411d255e
authored
Dec 08, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev_mtp_sampler' into v0.9.2-dev
parents
18b4e6f3
33e33aa7
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
143 additions
and
8 deletions
+143
-8
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+126
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+17
-4
No files found.
vllm/v1/worker/gpu_input_batch.py
View file @
411d255e
...
@@ -8,6 +8,7 @@ from typing import Optional, cast
...
@@ -8,6 +8,7 @@ from typing import Optional, cast
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm
import
envs
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
...
@@ -16,6 +17,10 @@ from vllm.utils import swap_dict_values
...
@@ -16,6 +17,10 @@ from vllm.utils import swap_dict_values
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.sample.logits_processor
import
(
BatchUpdateBuilder
,
from
vllm.v1.sample.logits_processor
import
(
BatchUpdateBuilder
,
LogitBiasLogitsProcessor
,
LogitsProcessorManager
,
MinPLogitsProcessor
,
MinTokensLogitsProcessor
,
MoveDirectionality
,
MoveDirectionality
,
init_builtin_logitsprocs
)
init_builtin_logitsprocs
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
@@ -192,6 +197,10 @@ class InputBatch:
...
@@ -192,6 +197,10 @@ class InputBatch:
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_reqs
:
set
[
str
]
=
set
()
self
.
repetition_penalties_reqs
:
set
[
str
]
=
set
()
# Track whether sampling metadata is currently expanded to
# per-token shape (spec decode reject path).
self
.
_sampling_metadata_is_expanded
=
False
# lora related
# lora related
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,
),
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,
),
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
...
@@ -593,7 +602,7 @@ class InputBatch:
...
@@ -593,7 +602,7 @@ class InputBatch:
del
self
.
_req_ids
[
self
.
num_reqs
:]
del
self
.
_req_ids
[
self
.
num_reqs
:]
del
self
.
req_output_token_ids
[
self
.
num_reqs
:]
del
self
.
req_output_token_ids
[
self
.
num_reqs
:]
def
refresh_metadata
(
self
):
def
refresh_metadata
(
self
,
repeat_counts
:
Optional
[
torch
.
Tensor
]
=
None
):
"""Apply batch updates, reset input batch at end of step
"""Apply batch updates, reset input batch at end of step
* Apply batch add/remove/permute to logits procs' states
* Apply batch add/remove/permute to logits procs' states
...
@@ -602,8 +611,17 @@ class InputBatch:
...
@@ -602,8 +611,17 @@ class InputBatch:
batch_update
=
self
.
batch_update_builder
.
get_and_reset
(
self
.
num_reqs
)
batch_update
=
self
.
batch_update_builder
.
get_and_reset
(
self
.
num_reqs
)
for
logit_proc
in
self
.
logitsprocs
.
all
:
for
logit_proc
in
self
.
logitsprocs
.
all
:
logit_proc
.
update_state
(
batch_update
)
logit_proc
.
update_state
(
batch_update
)
if
batch_update
:
needs_rebuild
=
(
batch_update
or
repeat_counts
is
not
None
or
self
.
_sampling_metadata_is_expanded
)
if
needs_rebuild
:
if
repeat_counts
is
None
:
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
else
:
self
.
sampling_metadata
=
self
.
_make_sampling_metadata_expanded
(
repeat_counts
)
self
.
_sampling_metadata_is_expanded
=
repeat_counts
is
not
None
# Expanded metadata is built on demand; do not cache a copy here.
def
_make_sampling_metadata
(
self
)
->
SamplingMetadata
:
def
_make_sampling_metadata
(
self
)
->
SamplingMetadata
:
num_reqs
=
self
.
num_reqs
num_reqs
=
self
.
num_reqs
...
@@ -666,6 +684,105 @@ class InputBatch:
...
@@ -666,6 +684,105 @@ class InputBatch:
logitsprocs
=
self
.
logitsprocs
,
logitsprocs
=
self
.
logitsprocs
,
)
)
def _make_sampling_metadata_expanded(
        self, repeat_counts: torch.Tensor) -> SamplingMetadata:
    """Build sampling metadata with each request's rows repeated.

    Used on the spec decode reject-sampling path, where per-request
    sampling parameters are expanded to per-token shape. Request ``i``
    contributes ``repeat_counts[i]`` consecutive rows to every expanded
    tensor, list and index-keyed dict in the result.

    Args:
        repeat_counts: Number of rows to emit for each request. It is
            passed to ``repeat_interleave`` against CPU tensors and read
            via ``.tolist()``; presumably a 1-D integer CPU tensor of
            length ``self.num_reqs`` -- confirm against the caller.

    Returns:
        A ``SamplingMetadata`` whose tensor fields have been moved to
        ``self.device``. ``generators`` and ``bad_words_token_ids`` are
        keyed by expanded row index instead of request index, and
        ``output_token_ids`` holds one entry per expanded row.
    """
    num_reqs = self.num_reqs

    def _expand_cpu_to_gpu(
        t: Optional[torch.Tensor],
        *,
        dtype: Optional[torch.dtype] = None,
    ) -> Optional[torch.Tensor]:
        # Slice to the live requests, repeat each row, then copy to the
        # device. ``dtype=None`` leaves the tensor's dtype unchanged.
        if t is None:
            return None
        expanded = t[:num_reqs].repeat_interleave(repeat_counts, dim=0)
        return expanded.to(device=self.device,
                           dtype=dtype,
                           non_blocking=True)

    # The prompt tokens are used only for applying penalties or step
    # pooling during the sampling/pooling process, so build the tensor
    # only when some request needs it.
    needs_prompt_token_ids = (
        not self.no_penalties
        or (self.num_reqs > 0 and self.logits_processing_needs_token_ids))
    prompt_token_ids = (
        self._make_prompt_token_ids_tensor(repeat_counts)
        if needs_prompt_token_ids else None)

    allowed_token_ids_mask: Optional[torch.Tensor] = None
    if not self.no_allowed_token_ids:
        assert self.allowed_token_ids_mask is not None
        # Expand from the CPU copy; the helper performs the device copy.
        allowed_token_ids_mask = self.allowed_token_ids_mask_cpu_tensor

    top_p_cpu = None if self.no_top_p else self.top_p_cpu_tensor
    top_k_cpu = None if self.no_top_k else self.top_k_cpu_tensor

    # Expand the Python-side (non-tensor) per-request state. Every row
    # of a request shares the same output-token list, bad-words list and
    # generator object as the request itself (no copies are made).
    repeat_list = repeat_counts.tolist()
    expanded_output_token_ids: list[list[int]] = []
    expanded_bad_words_token_ids: dict[int, list[list[int]]] = {}
    expanded_generators: dict[int, torch.Generator] = {}
    row_idx = 0
    for req_idx in range(num_reqs):
        repeat = int(repeat_list[req_idx])
        if repeat <= 0:
            # A request with no rows contributes nothing.
            continue
        output_tokens = self.req_output_token_ids[req_idx]
        assert output_tokens is not None
        expanded_output_token_ids.extend([output_tokens] * repeat)

        rows = range(row_idx, row_idx + repeat)
        bad_words = self.bad_words_token_ids.get(req_idx)
        if bad_words is not None:
            expanded_bad_words_token_ids.update(
                dict.fromkeys(rows, bad_words))
        generator = self.generators.get(req_idx)
        if generator is not None:
            expanded_generators.update(dict.fromkeys(rows, generator))
        row_idx += repeat

    all_greedy = self.all_greedy
    no_penalties = self.no_penalties
    return SamplingMetadata(
        temperature=_expand_cpu_to_gpu(
            None if all_greedy else self.temperature_cpu_tensor),
        all_greedy=all_greedy,
        all_random=self.all_random,
        top_p=_expand_cpu_to_gpu(top_p_cpu),
        top_k=_expand_cpu_to_gpu(top_k_cpu, dtype=torch.int32),
        generators=expanded_generators,
        max_num_logprobs=self.max_num_logprobs,
        prompt_token_ids=prompt_token_ids,
        frequency_penalties=(None if no_penalties else _expand_cpu_to_gpu(
            self.frequency_penalties_cpu_tensor)),
        presence_penalties=(None if no_penalties else _expand_cpu_to_gpu(
            self.presence_penalties_cpu_tensor)),
        repetition_penalties=(None if no_penalties else _expand_cpu_to_gpu(
            self.repetition_penalties_cpu_tensor)),
        output_token_ids=expanded_output_token_ids,
        no_penalties=no_penalties,
        allowed_token_ids_mask=_expand_cpu_to_gpu(allowed_token_ids_mask,
                                                  dtype=torch.bool),
        bad_words_token_ids=expanded_bad_words_token_ids,
        logitsprocs=self.logitsprocs,
    )
@
property
@
property
def
pooling_metadata
(
self
)
->
PoolingMetadata
:
def
pooling_metadata
(
self
)
->
PoolingMetadata
:
if
len
(
self
.
pooling_params
)
==
0
:
if
len
(
self
.
pooling_params
)
==
0
:
...
@@ -685,7 +802,9 @@ class InputBatch:
...
@@ -685,7 +802,9 @@ class InputBatch:
pooling_params
=
pooling_params
,
pooling_params
=
pooling_params
,
)
)
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
def
_make_prompt_token_ids_tensor
(
self
,
repeat_counts_cpu
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
max_prompt_len
=
self
.
num_prompt_tokens
[:
self
.
num_reqs
].
max
()
max_prompt_len
=
self
.
num_prompt_tokens
[:
self
.
num_reqs
].
max
()
prompt_token_ids_cpu_tensor
=
torch
.
empty
(
prompt_token_ids_cpu_tensor
=
torch
.
empty
(
(
self
.
num_reqs
,
max_prompt_len
),
(
self
.
num_reqs
,
max_prompt_len
),
...
@@ -700,6 +819,9 @@ class InputBatch:
...
@@ -700,6 +819,9 @@ class InputBatch:
# token_id of this value.
# token_id of this value.
for
i
in
range
(
self
.
num_reqs
):
for
i
in
range
(
self
.
num_reqs
):
prompt_token_ids
[
i
,
self
.
num_prompt_tokens
[
i
]:]
=
self
.
vocab_size
prompt_token_ids
[
i
,
self
.
num_prompt_tokens
[
i
]:]
=
self
.
vocab_size
if
repeat_counts_cpu
is
not
None
:
prompt_token_ids_cpu_tensor
=
prompt_token_ids_cpu_tensor
\
.
repeat_interleave
(
repeat_counts_cpu
,
dim
=
0
)
return
prompt_token_ids_cpu_tensor
.
to
(
device
=
self
.
device
,
return
prompt_token_ids_cpu_tensor
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
non_blocking
=
True
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
411d255e
...
@@ -572,8 +572,21 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -572,8 +572,21 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self
.
input_batch
.
condense
()
self
.
input_batch
.
condense
()
# Allow attention backend to reorder the batch, potentially
# Allow attention backend to reorder the batch, potentially
self
.
_may_reorder_batch
(
scheduler_output
)
self
.
_may_reorder_batch
(
scheduler_output
)
# Refresh batch metadata with any pending updates.
# Refresh batch metadata with any pending updates. If we are in spec
self
.
input_batch
.
refresh_metadata
()
# decode + reject mode, also expand sampling metadata to token shape
# using per-request repeat counts.
repeat_counts
:
Optional
[
torch
.
Tensor
]
=
None
if
envs
.
VLLM_REJECT_SAMPLE_OPT
and
\
scheduler_output
.
scheduled_spec_decode_tokens
:
num_reqs
=
self
.
input_batch
.
num_reqs
num_draft_tokens
=
np
.
zeros
(
num_reqs
,
dtype
=
np
.
int32
)
for
req_id
,
draft_token_ids
in
(
scheduler_output
.
scheduled_spec_decode_tokens
.
items
()):
req_idx
=
self
.
input_batch
.
req_id_to_index
.
get
(
req_id
)
if
req_idx
is
not
None
:
num_draft_tokens
[
req_idx
]
=
len
(
draft_token_ids
)
repeat_counts
=
torch
.
from_numpy
(
num_draft_tokens
).
add_
(
1
)
self
.
input_batch
.
refresh_metadata
(
repeat_counts
)
def
_get_cumsum_and_arange
(
def
_get_cumsum_and_arange
(
self
,
self
,
...
@@ -3360,8 +3373,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
...
@@ -3360,8 +3373,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
)
)
sampler_output
.
sampled_token_ids
=
output_token_ids
sampler_output
.
sampled_token_ids
=
output_token_ids
else
:
else
:
sampling_metadata
.
all_greedy
=
True
#
sampling_metadata.all_greedy = True
sampling_metadata
.
all_random
=
False
#
sampling_metadata.all_random = False
sampler_output
=
self
.
sampler
(
sampler_output
=
self
.
sampler
(
logits
=
logits
,
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment