Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0ae9ce75
Commit
0ae9ce75
authored
Nov 21, 2025
by
jujl1
Browse files
feat: pp mtp加入零消耗调度,减少空泡
parent
d8ea775f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
148 additions
and
35 deletions
+148
-35
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+148
-35
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
0ae9ce75
...
...
@@ -88,7 +88,7 @@ else:
"xgrammar.kernels.apply_token_bitmask_inplace_torch_compile"
)
logger
=
init_logger
(
__name__
)
from
vllm.zero_overhead.v1.eagle
import
V1ZeroEagleProposer
class
GPUModelRunner
(
LoRAModelRunnerMixin
):
...
...
@@ -134,7 +134,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
self
.
spec_sampler_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
spec_scheduler_max_num_tokens
=
0
# Model-related.
self
.
num_query_heads
=
model_config
.
get_num_attention_heads
(
parallel_config
)
...
...
@@ -182,20 +183,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the last PP rank. This is not ideal if there are many
# layers in the draft model.
if
self
.
speculative_config
and
get_pp_group
().
is_last_rank
:
if
self
.
speculative_config
.
method
==
"ngram"
:
self
.
drafter
=
NgramProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
self
.
drafter
=
EagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
# type: ignore
if
self
.
speculative_config
.
method
==
"eagle3"
:
self
.
use_aux_hidden_state_outputs
=
True
elif
self
.
speculative_config
.
method
==
"medusa"
:
self
.
drafter
=
MedusaProposer
(
vllm_config
=
self
.
vllm_config
,
device
=
self
.
device
)
# type: ignore
else
:
raise
ValueError
(
"Unknown speculative decoding method: "
f
"
{
self
.
speculative_config
.
method
}
"
)
self
.
drafter
=
V1ZeroEagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
# if self.speculative_config.method == "ngram":
# self.drafter = NgramProposer(self.vllm_config)
# elif self.speculative_config.use_eagle():
# self.drafter = EagleProposer(self.vllm_config, self.device,
# self) # type: ignore
# if self.speculative_config.method == "eagle3":
# self.use_aux_hidden_state_outputs = True
# elif self.speculative_config.method == "medusa":
# self.drafter = MedusaProposer(
# vllm_config=self.vllm_config,
# device=self.device) # type: ignore
# else:
# raise ValueError("Unknown speculative decoding method: "
# f"{self.speculative_config.method}")
self
.
rejection_sampler
=
RejectionSampler
()
# Request states.
...
...
@@ -609,7 +612,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
max
(
tokens
)
self
.
spec_scheduler_max_num_tokens
=
max_num_scheduled_tokens
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices
=
np
.
repeat
(
self
.
arange_np
[:
num_reqs
],
...
...
@@ -1543,18 +1546,39 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
[:
num_scheduled_tokens
],
scheduler_output
,
)
#-----------------------------------
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
sampled_token_ids_cpu
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
else
:
self
.
spec_sampler_event
.
record
()
mask
=
(
sampled_token_ids
==
-
1
)
mask_int
=
mask
.
int
()
first_neg_one_indices
=
torch
.
argmax
(
mask_int
,
dim
=
1
)
num_accepted_tokens_tensor
=
torch
.
where
(
torch
.
any
(
mask
,
dim
=
1
),
first_neg_one_indices
,
sampled_token_ids
.
size
(
1
))
-
1
spec_token_ids
=
self
.
zero_propose_draft_token_ids
(
scheduler_output
,
num_accepted_tokens_tensor
,
sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
if
max_gen_len
==
1
:
# No spec decode tokens.
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
valid_sampled_token_ids
=
sampled_token_ids
_cpu
.
tolist
()
else
:
# Includes spec decode tokens.
self
.
spec_sampler_event
.
synchronize
()
valid_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
sampled_token_ids
,
sampled_token_ids
_cpu
,
self
.
input_batch
.
vocab_size
,
)
# Mask out the sampled tokens that should not be sampled.
...
...
@@ -1585,20 +1609,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
req_state
=
self
.
requests
[
req_id
]
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
else
:
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
valid_sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
...
...
@@ -1619,6 +1629,109 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_nans_in_logits
=
num_nans_in_logits
,
)
def zero_propose_draft_token_ids(
    self,
    scheduler_output: "SchedulerOutput",
    num_accepted_tokens_tensor: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    hidden_states: torch.Tensor,
    sample_hidden_states: torch.Tensor,
    aux_hidden_states: Optional[torch.Tensor],
    spec_decode_metadata: Optional[SpecDecodeMetadata],
    attn_metadata: dict[str, Any],
) -> list[list[int]]:
    """Propose draft token ids from the sampler output tensor.

    Dispatches on ``self.speculative_config``: ``method == "ngram"``,
    ``method == "medusa"``, or any method for which ``use_eagle()`` is
    true.

    Args:
        scheduler_output: Only ``total_num_scheduled_tokens`` is read.
        num_accepted_tokens_tensor: Per-row column index into
            ``sampled_token_ids``; used to pick each row's next token on
            the eagle path and forwarded to ``drafter.prepare_inputs``.
        sampled_token_ids: Sampler output, indexed as a 2-D tensor on the
            eagle path.
        sampling_metadata: Forwarded unchanged to ``drafter.propose``.
        hidden_states: Target-model hidden states (eagle path, when aux
            hidden states are not used).
        sample_hidden_states: Hidden states consumed by the medusa path.
        aux_hidden_states: Iterated and concatenated along the last dim
            when ``self.use_aux_hidden_state_outputs`` is set.
        spec_decode_metadata: ``None`` selects the "no draft tokens in the
            input" path; otherwise inputs are gathered via
            ``drafter.prepare_inputs``.
        attn_metadata: Mapping from attention layer name to its metadata.

    Returns:
        The proposer's result for ngram/medusa, or
        ``draft_token_ids.tolist()`` for eagle.

    Raises:
        ValueError: If the configured speculative method matches none of
            the supported branches (previously this surfaced as an
            ``UnboundLocalError`` at the ``return``).
    """
    num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    if self.speculative_config.method == "ngram":
        assert isinstance(self.drafter, NgramProposer)
        # NOTE(review): the sampled ids are passed here as given by the
        # caller (a tensor per the annotation) — confirm the ngram helper
        # accepts that form.
        spec_token_ids = self.propose_ngram_draft_token_ids(
            sampled_token_ids)
    elif self.speculative_config.method == "medusa":
        assert isinstance(self.drafter, MedusaProposer)
        if sample_hidden_states.shape[0] == len(sampled_token_ids):
            # The input to the target model does not include draft tokens.
            hidden_states = sample_hidden_states
        else:
            # Pick, for every request, the hidden state at its last
            # sampled position; each request occupies num_draft + 1 rows.
            # NOTE(review): len(tokens) is the row length of
            # sampled_token_ids; if rows are padded this is not the
            # number of accepted tokens — TODO confirm.
            indices = []
            offset = 0
            for num_draft, tokens in zip(
                    spec_decode_metadata.num_draft_tokens,
                    sampled_token_ids):
                indices.append(offset + len(tokens) - 1)
                offset += num_draft + 1
            indices = torch.tensor(indices, device=self.device)
            hidden_states = sample_hidden_states[indices]

        spec_token_ids = self.drafter.propose(
            target_hidden_states=hidden_states,
            sampling_metadata=sampling_metadata,
        )
    elif self.speculative_config.use_eagle():
        assert isinstance(self.drafter, EagleProposer)
        # TODO(woosuk): Refactor the loop.
        # Gather one token per row at the column given by
        # num_accepted_tokens_tensor, without leaving the device.
        row_indices = torch.arange(sampled_token_ids.size(0),
                                   device=sampled_token_ids.device)
        next_token_ids = sampled_token_ids[
            row_indices, num_accepted_tokens_tensor].flatten()
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        eagle_attn_metadata = attn_metadata[
            self.drafter.attn_layer_names[0]]

        # NOTE: deepseek_mtp uses MLA which does not have `block_table`
        if hasattr(eagle_attn_metadata, "block_table"):
            block_table = eagle_attn_metadata.block_table
        else:
            block_table = None

        # Defaults to the value recorded on the runner; overridden to 1
        # below when draft tokens were part of the input.
        spec_scheduler_max_num_tokens = self.spec_scheduler_max_num_tokens
        if spec_decode_metadata is None:
            # input_ids can be None for multimodal models.
            target_token_ids = self.input_ids[:num_scheduled_tokens]
            # TODO(woosuk): Support M-RoPE.
            target_positions = self.positions[:num_scheduled_tokens]
            if self.use_aux_hidden_state_outputs:
                target_hidden_states = torch.cat(
                    [h[:num_scheduled_tokens] for h in aux_hidden_states],
                    dim=-1)
            else:
                target_hidden_states = hidden_states[:num_scheduled_tokens]
            target_slot_mapping = eagle_attn_metadata.slot_mapping
            cu_num_tokens = eagle_attn_metadata.query_start_loc
        else:
            # TODO(woosuk): Refactor this.
            cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                eagle_attn_metadata.query_start_loc,
                num_accepted_tokens_tensor,
            )
            spec_scheduler_max_num_tokens = 1
            target_token_ids = self.input_ids[token_indices]
            # TODO(woosuk): Support M-RoPE.
            target_positions = self.positions[token_indices]
            if self.use_aux_hidden_state_outputs:
                target_hidden_states = torch.cat(
                    [h[token_indices] for h in aux_hidden_states], dim=-1)
            else:
                target_hidden_states = hidden_states[token_indices]
            target_slot_mapping = eagle_attn_metadata.slot_mapping[
                token_indices]

        # The drafter reads this attribute; it is set right before the
        # propose() call.
        self.drafter.spec_scheduler_max_num_tokens = (
            spec_scheduler_max_num_tokens)
        draft_token_ids = self.drafter.propose(
            target_token_ids=target_token_ids,
            target_positions=target_positions,
            target_hidden_states=target_hidden_states,
            target_slot_mapping=target_slot_mapping,
            next_token_ids=next_token_ids,
            cu_num_tokens=cu_num_tokens,
            block_table=block_table,
            sampling_metadata=sampling_metadata,
            decoding=spec_decode_metadata is not None,
        )
        spec_token_ids = draft_token_ids.tolist()
    else:
        # Fail loudly instead of hitting an unbound local at the return.
        raise ValueError("Unknown speculative decoding method: "
                         f"{self.speculative_config.method}")
    return spec_token_ids
def
propose_draft_token_ids
(
self
,
scheduler_output
:
"SchedulerOutput"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment