Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2062c072
Unverified
Commit
2062c072
authored
Jun 30, 2025
by
Woosuk Kwon
Committed by
GitHub
Jun 30, 2025
Browse files
[Spec Decode] Refactor spec decoding into a separate function (#20238)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
1c50e100
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
60 additions
and
33 deletions
+60
-33
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+60
-33
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
2062c072
...
@@ -1388,6 +1388,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1388,6 +1388,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
,
aux_hidden_states
=
model_output
hidden_states
,
aux_hidden_states
=
model_output
else
:
else
:
hidden_states
=
model_output
hidden_states
=
model_output
aux_hidden_states
=
None
# Broadcast PP output for external_launcher (torchrun)
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
# TODO: Support overlapping mirco-batches
...
@@ -1510,25 +1512,67 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1510,25 +1512,67 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
not
self
.
speculative_config
:
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
# Speculative decoding is not enabled.
spec_token_ids
=
None
spec_token_ids
=
None
elif
self
.
speculative_config
.
method
==
"ngram"
:
else
:
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
valid_sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
self
.
eplb_step
()
return
ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
valid_sampled_token_ids
,
spec_token_ids
=
spec_token_ids
,
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
pooler_output
=
[],
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
num_nans_in_logits
=
num_nans_in_logits
,
)
def
propose_draft_token_ids
(
self
,
scheduler_output
:
"SchedulerOutput"
,
sampled_token_ids
:
list
[
list
[
int
]],
sampling_metadata
:
SamplingMetadata
,
hidden_states
:
torch
.
Tensor
,
sample_hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
Optional
[
torch
.
Tensor
],
spec_decode_metadata
:
Optional
[
SpecDecodeMetadata
],
attn_metadata
:
dict
[
str
,
Any
],
)
->
list
[
list
[
int
]]:
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
self
.
speculative_config
.
method
==
"ngram"
:
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
spec_token_ids
=
self
.
generate
_draft_token_ids
(
spec_token_ids
=
self
.
propose_ngram
_draft_token_ids
(
valid_
sampled_token_ids
,
sampling_metadata
)
sampled_token_ids
)
elif
self
.
speculative_config
.
method
==
"medusa"
:
elif
self
.
speculative_config
.
method
==
"medusa"
:
assert
isinstance
(
self
.
drafter
,
MedusaProposer
)
assert
isinstance
(
self
.
drafter
,
MedusaProposer
)
if
max_gen_len
==
1
:
if
sample_hidden_states
.
shape
[
0
]
==
len
(
sampled_token_ids
):
# The input to the target model does not include draft tokens.
hidden_states
=
sample_hidden_states
hidden_states
=
sample_hidden_states
else
:
else
:
indices
=
[]
indices
=
[]
offset
=
0
offset
=
0
for
num_draft
,
tokens
in
zip
(
for
num_draft
,
tokens
in
zip
(
spec_decode_metadata
.
num_draft_tokens
,
spec_decode_metadata
.
num_draft_tokens
,
valid_
sampled_token_ids
):
sampled_token_ids
):
indices
.
append
(
offset
+
len
(
tokens
)
-
1
)
indices
.
append
(
offset
+
len
(
tokens
)
-
1
)
offset
+=
num_draft
+
1
offset
+=
num_draft
+
1
indices
=
torch
.
tensor
(
indices
,
device
=
self
.
device
)
indices
=
torch
.
tensor
(
indices
,
device
=
sample_hidden_states
.
device
)
hidden_states
=
sample_hidden_states
[
indices
]
hidden_states
=
sample_hidden_states
[
indices
]
spec_token_ids
=
self
.
drafter
.
propose
(
spec_token_ids
=
self
.
drafter
.
propose
(
...
@@ -1539,7 +1583,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1539,7 +1583,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# TODO(woosuk): Refactor the loop.
# TODO(woosuk): Refactor the loop.
next_token_ids
:
list
[
int
]
=
[]
next_token_ids
:
list
[
int
]
=
[]
for
i
,
token_ids
in
enumerate
(
valid_
sampled_token_ids
):
for
i
,
token_ids
in
enumerate
(
sampled_token_ids
):
if
token_ids
:
if
token_ids
:
# Common case.
# Common case.
next_token_id
=
token_ids
[
-
1
]
next_token_id
=
token_ids
[
-
1
]
...
@@ -1569,7 +1613,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1569,7 +1613,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
spec_decode_metadata
is
None
:
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
# input_ids can be None for multimodal models.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_positions
=
positions
[:
num_scheduled_tokens
]
# TODO(woosuk): Support M-RoPE.
target_positions
=
self
.
positions
[:
num_scheduled_tokens
]
if
self
.
use_aux_hidden_state_outputs
:
if
self
.
use_aux_hidden_state_outputs
:
target_hidden_states
=
torch
.
cat
(
target_hidden_states
=
torch
.
cat
(
[
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
],
[
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
],
...
@@ -1582,7 +1627,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1582,7 +1627,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): Refactor this.
# TODO(woosuk): Refactor this.
num_draft_tokens
=
spec_decode_metadata
.
num_draft_tokens
num_draft_tokens
=
spec_decode_metadata
.
num_draft_tokens
num_rejected_tokens
=
[
num_rejected_tokens
=
[
n
+
1
-
len
(
valid_
sampled_token_ids
[
i
])
if
n
>
0
else
0
n
+
1
-
len
(
sampled_token_ids
[
i
])
if
n
>
0
else
0
for
i
,
n
in
enumerate
(
num_draft_tokens
)
for
i
,
n
in
enumerate
(
num_draft_tokens
)
]
]
num_rejected_tokens_tensor
=
async_tensor_h2d
(
num_rejected_tokens_tensor
=
async_tensor_h2d
(
...
@@ -1597,7 +1642,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1597,7 +1642,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens
,
num_tokens
,
)
)
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_positions
=
positions
[
token_indices
]
# TODO(woosuk): Support M-RoPE.
target_positions
=
self
.
positions
[
token_indices
]
if
self
.
use_aux_hidden_state_outputs
:
if
self
.
use_aux_hidden_state_outputs
:
target_hidden_states
=
torch
.
cat
(
target_hidden_states
=
torch
.
cat
(
[
h
[
token_indices
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
[
h
[
token_indices
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
...
@@ -1616,25 +1662,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1616,25 +1662,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
spec_token_ids
=
draft_token_ids
.
tolist
()
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
self
.
eplb_step
()
return
ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
valid_sampled_token_ids
,
spec_token_ids
=
spec_token_ids
,
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
pooler_output
=
[],
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
num_nans_in_logits
=
num_nans_in_logits
,
)
def
kv_connector_no_forward
(
def
kv_connector_no_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
ModelRunnerOutput
:
self
,
scheduler_output
:
"SchedulerOutput"
)
->
ModelRunnerOutput
:
...
@@ -1682,10 +1710,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1682,10 +1710,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
.
finished_req_ids
)
scheduler_output
.
finished_req_ids
)
return
None
,
None
return
None
,
None
def
generate
_draft_token_ids
(
def
propose_ngram
_draft_token_ids
(
self
,
self
,
sampled_token_ids
:
list
[
list
[
int
]],
sampled_token_ids
:
list
[
list
[
int
]],
sampling_metadata
:
SamplingMetadata
,
)
->
list
[
list
[
int
]]:
)
->
list
[
list
[
int
]]:
# TODO(woosuk): Optimize.
# TODO(woosuk): Optimize.
draft_token_ids
:
list
[
list
[
int
]]
=
[]
draft_token_ids
:
list
[
list
[
int
]]
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment