Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0936ee97
Commit
0936ee97
authored
Nov 28, 2025
by
jujl1
Browse files
feat: 主模型+mtp提前返回
parent
cd42bf87
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
96 deletions
+107
-96
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+19
-13
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+88
-83
No files found.
vllm/zero_overhead/v1/core.py
View file @
0936ee97
import
torch
from
collections
import
defaultdict
from
typing
import
Optional
...
...
@@ -89,9 +88,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
else
:
request
.
_output_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
request
.
_output_token_ids
[:]
=
request
.
_output_token_ids
[:
requsets_valid_token_len
[
req_id
]]
request
.
_all_token_ids
[:]
=
request
.
_all_token_ids
[:
request
.
num_prompt_tokens
+
requsets_valid_token_len
[
req_id
]]
stopped
=
False
new_logprobs
=
None
new_token_ids
=
generated_token_ids
...
...
@@ -110,7 +110,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if
pooler_outputs
:
pooler_output
=
pooler_outputs
[
req_idx
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
,
True
)
pooler_output
=
pooler_output
,
use_valid_token_len
=
True
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
...
...
@@ -191,7 +191,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
generated_token_ids
=
sampled_token_ids
[
req_index
]
if
sampled_token_ids
else
[]
req_index
]
if
sampled_token_ids
else
[]
scheduled_spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
))
...
...
@@ -202,13 +203,18 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# num_computed_tokens is decreased by the number of rejected
# tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
len
(
generated_token_ids
))
num_new
=
len
(
generated_token_ids
)
if
(
model_runner_output
.
fix_req_ids
and
req_id
in
model_runner_output
.
fix_req_ids
and
request
.
num_computed_tokens
<=
request
.
num_prompt_tokens
+
num_new
):
req_idx
=
model_runner_output
.
fix_req_ids
.
index
(
req_id
)
num_new
=
len
(
model_runner_output
.
fix_sampled_token_ids
[
req_idx
])
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
num_new
)
request
.
num_computed_tokens
-=
num_tokens_rejected
spec_decoding_stats
=
scheduler
.
make_spec_decoding_stats
(
spec_decoding_stats
,
num_draft_tokens
=
len
(
scheduled_spec_token_ids
),
num_accepted_tokens
=
len
(
generated_token_ids
)
-
1
)
num_accepted_tokens
=
(
num_new
-
1
)
if
generated_token_ids
else
0
)
# NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`.
...
...
@@ -231,7 +237,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if
model_runner_output
.
is_output_valid
:
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
False
)
use_valid_token_len
=
False
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
...
...
@@ -242,8 +248,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if
model_runner_output
.
is_output_valid
:
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
,
False
)
pooler_output
,
use_valid_token_len
=
False
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
...
...
@@ -350,10 +356,10 @@ def engine_core_step(core) -> tuple[dict[int, EngineCoreOutputs], bool]:
model_output
=
core
.
execute_model
(
scheduler_output
)
if
isinstance
(
model_output
,
ZeroV1ModelRunnerOutput
):
engine_core_outputs
=
zero_overhead_update_from_output
(
core
.
scheduler
,
scheduler_output
,
model_output
)
# type: ignore
scheduler_output
,
model_output
)
# type: ignore
else
:
engine_core_outputs
=
core
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
# type: ignore
return
(
engine_core_outputs
,
scheduler_output
.
total_num_scheduled_tokens
>
0
)
\ No newline at end of file
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
0936ee97
from
typing
import
Any
,
Optional
,
Union
import
torch
import
numpy
as
np
...
...
@@ -39,6 +38,9 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
last_draft_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
spec_sampler_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
spec_scheduler_max_num_tokens
=
0
self
.
fix_req_ids
=
None
self
.
fix_sampled_token_ids
=
None
if
hasattr
(
self
,
'drafter'
)
and
isinstance
(
self
.
drafter
,
EagleProposer
):
self
.
drafter
=
V1ZeroEagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
...
...
@@ -81,6 +83,39 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens
,
arange
=
self
.
_get_cumsum_and_arange
(
num_scheduled_tokens
)
self
.
fix_req_ids
=
self
.
last_sampled_req_ids
if
self
.
last_sampler_host_tokens
!=
None
:
self
.
last_sampler_event
.
synchronize
()
# 等上一轮主模型结束
if
self
.
speculative_config
:
# 处理上一轮mtp
num_gen_tokens
=
self
.
last_sampler_host_tokens
.
shape
[
-
1
]
if
num_gen_tokens
==
1
:
self
.
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
else
:
# Includes spec decode tokens.
self
.
fix_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
self
.
last_sampler_host_tokens
,
self
.
input_batch
.
vocab_size
,
)
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
if
start_idx
==
-
1
:
self
.
fix_sampled_token_ids
[
req_idx
].
clear
()
else
:
num_accepted_tokens
=
len
(
self
.
fix_sampled_token_ids
[
req_idx
])
req_id
=
self
.
fix_req_ids
[
req_idx
]
if
req_id
in
self
.
input_batch
.
req_ids
:
new_req_idx
=
self
.
input_batch
.
req_ids
.
index
(
req_id
)
new_end_idx
=
start_idx
+
num_accepted_tokens
# # 更新token统计数据
self
.
input_batch
.
num_tokens_no_spec
[
new_req_idx
]
=
new_end_idx
self
.
input_batch
.
num_tokens
[
new_req_idx
]
=
new_end_idx
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
new_end_idx
]
=
self
.
fix_sampled_token_ids
[
req_idx
]
self
.
input_batch
.
num_computed_tokens_cpu
[
new_req_idx
]
-=
(
end_idx
-
new_end_idx
)
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
req_state
.
output_token_ids
.
extend
(
self
.
fix_sampled_token_ids
[
req_idx
])
# Get positions.
positions_np
=
self
.
positions_np
[:
total_num_scheduled_tokens
]
np
.
add
(
self
.
input_batch
.
num_computed_tokens_cpu
[
req_indices
],
...
...
@@ -267,15 +302,26 @@ class V1ZeroModelRunner(GPUModelRunner):
True
)
last_draft_token_ids
=
self
.
last_draft_token_ids
.
flatten
().
to
(
torch
.
int
)
input_ids
[
input_ids_indices_tensor
]
=
last_draft_token_ids
[
update_req_indices_tensor
]
def
find_last_valid_vectorized
(
tensor
):
"""
向量化方法找到每行最后一个非-1元素
"""
mask
=
tensor
!=
-
1
reversed_mask
=
mask
.
flip
(
dims
=
[
1
])
# 沿着列方向反转
_
,
col_indices
=
torch
.
max
(
reversed_mask
.
int
(),
dim
=
1
)
original_col_indices
=
tensor
.
size
(
1
)
-
1
-
col_indices
result
=
tensor
[
torch
.
arange
(
tensor
.
size
(
0
)),
original_col_indices
]
all_invalid
=
~
mask
.
any
(
dim
=
1
)
result
[
all_invalid
]
=
-
1
# 或者设置为其他默认值
return
result
update_req_indices
=
[]
input_ids_indices
=
[]
token_idx
=
0
if
self
.
last_sampled_token_ids
is
not
None
:
sampled_tokens_num
=
self
.
last_sampled_token_ids
.
shape
[
1
]
for
req_id
in
req_ids
:
if
req_id
in
self
.
last_sampled_req_ids
:
req_idx
=
self
.
last_sampled_req_ids
.
index
(
req_id
)
*
sampled_tokens_num
req_idx
=
self
.
last_sampled_req_ids
.
index
(
req_id
)
update_req_indices
.
append
(
req_idx
)
input_ids_indices
.
append
(
token_idx
)
token_idx
+=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
...
...
@@ -286,9 +332,9 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
self
.
device
,
True
)
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
flatten
()
for
i
in
range
(
sampled_tokens_num
):
input_ids
[
input_ids_indices_tensor
+
i
]
=
last_sampled_token_ids
[
update_req_indices_tensor
+
i
]
last_sampled_token_ids
=
find_last_valid_vectorized
(
self
.
last_sampled_token_ids
)
.
flatten
()
input_ids
[
input_ids_indices_tensor
]
=
last_sampled_token_ids
[
update_req_indices_tensor
]
def
propose_draft_token_ids
(
self
,
...
...
@@ -660,80 +706,19 @@ class V1ZeroModelRunner(GPUModelRunner):
scheduler_output
,
)
fix_req_ids
=
None
fix_sampled_token_ids
=
None
fix_draft_token_ids
=
None
fix_draft_req_ids
=
self
.
last_sampled_req_ids
is_output_valid
=
False
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
fix_draft_req_ids
=
None
else
:
sampled_token_ids_cpu
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
spec_sampler_event
.
record
()
if
self
.
last_draft_host_tokens
is
not
None
:
self
.
last_draft_event
.
synchronize
()
fix_draft_token_ids
=
self
.
last_draft_host_tokens
.
tolist
()
mask
=
(
sampled_token_ids
==
-
1
)
mask_int
=
mask
.
int
()
first_neg_one_indices
=
torch
.
argmax
(
mask_int
,
dim
=
1
)
num_accepted_tokens_tensor
=
torch
.
where
(
torch
.
any
(
mask
,
dim
=
1
),
first_neg_one_indices
,
sampled_token_ids
.
size
(
1
))
-
1
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
num_accepted_tokens_tensor
,
sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
self
.
last_sampler_host_tokens
=
None
self
.
last_sampled_token_ids
=
None
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_sampler_event
.
record
()
self
.
last_sampled_token_ids
=
sampled_token_ids
valid_sampled_token_ids
=
np
.
ones
(
sampled_token_ids
.
shape
,
dtype
=
int
).
tolist
()
if
self
.
speculative_config
:
self
.
spec_sampler_event
.
synchronize
()
if
max_gen_len
==
1
:
valid_sampled_token_ids
=
sampled_token_ids_cpu
.
tolist
()
else
:
# Includes spec decode tokens.
valid_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
sampled_token_ids_cpu
,
self
.
input_batch
.
vocab_size
,
)
self
.
last_sampler_host_tokens
=
None
self
.
last_sampled_token_ids
=
None
is_output_valid
=
True
else
:
# No spec decode tokens.
fix_req_ids
=
self
.
last_sampled_req_ids
if
self
.
last_sampler_host_tokens
!=
None
:
self
.
last_sampler_event
.
synchronize
()
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
if
start_idx
==
-
1
:
continue
req_id
=
fix_req_ids
[
req_idx
]
if
req_id
in
self
.
input_batch
.
req_ids
:
new_req_idx
=
self
.
input_batch
.
req_ids
.
index
(
req_id
)
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
end_idx
]
=
fix_sampled_token_ids
[
req_idx
]
for
req_idx
,
req_id
in
enumerate
(
fix_req_ids
):
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
token_idx
=
self
.
last_sampled_token_lens
[
req_idx
]
if
token_idx
==
-
1
:
continue
fix_len
=
len
(
fix_sampled_token_ids
[
req_idx
])
req_state
.
output_token_ids
[
token_idx
:
token_idx
+
fix_len
]
=
fix_sampled_token_ids
[
req_idx
]
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_sampler_event
.
record
()
self
.
last_sampled_token_ids
=
sampled_token_ids
valid_sampled_token_ids
=
np
.
ones
(
sampled_token_ids
.
shape
,
dtype
=
int
).
tolist
()
# Mask out the sampled tokens that should not be sampled.
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
...
...
@@ -767,12 +752,32 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
start_idx
,
end_idx
])
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
cache_output_len
=
len
(
req_state
.
output_token_ids
)
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
self
.
last_sampled_token_lens
.
append
(
cache_output_len
)
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
fix_draft_req_ids
=
None
else
:
sampled_token_ids_cpu
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
if
self
.
last_draft_host_tokens
is
not
None
:
self
.
last_draft_event
.
synchronize
()
fix_draft_token_ids
=
self
.
last_draft_host_tokens
.
tolist
()
mask
=
(
sampled_token_ids
==
-
1
)
mask_int
=
mask
.
int
()
first_neg_one_indices
=
torch
.
argmax
(
mask_int
,
dim
=
1
)
num_accepted_tokens_tensor
=
torch
.
where
(
torch
.
any
(
mask
,
dim
=
1
),
first_neg_one_indices
,
sampled_token_ids
.
size
(
1
))
-
1
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
num_accepted_tokens_tensor
,
sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
...
...
@@ -791,10 +796,10 @@ class V1ZeroModelRunner(GPUModelRunner):
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
num_nans_in_logits
=
num_nans_in_logits
,
fix_req_ids
=
fix_req_ids
,
fix_sampled_token_ids
=
fix_sampled_token_ids
,
fix_draft_tokens_ids
=
fix_draft_token_ids
,
fix_draft_req_ids
=
fix_draft_req_ids
,
fix_req_ids
=
self
.
fix_req_ids
,
fix_sampled_token_ids
=
self
.
fix_sampled_token_ids
,
fix_draft_tokens_ids
=
fix_draft_token_ids
,
fix_draft_req_ids
=
fix_draft_req_ids
,
is_output_valid
=
is_output_valid
)
return
model_output
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment