Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96197e48
Commit
96197e48
authored
Dec 23, 2025
by
jujl1
Browse files
fix: support chunk-prefill and fix bug in check_stop
parent
89639c96
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
57 deletions
+45
-57
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+5
-3
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+40
-54
No files found.
vllm/zero_overhead/v1/core.py
View file @
96197e48
...
@@ -14,12 +14,13 @@ requsets_valid_token_len = {}
...
@@ -14,12 +14,13 @@ requsets_valid_token_len = {}
def
check_stop
(
request
:
Request
,
def
check_stop
(
request
:
Request
,
max_model_len
:
int
,
max_model_len
:
int
,
pooler_output
:
Optional
[
torch
.
Tensor
]
=
None
,
pooler_output
:
Optional
[
torch
.
Tensor
]
=
None
,
use_valid_token_len
:
bool
=
False
)
->
bool
:
use_valid_token_len
:
bool
=
False
,
last_token_offset
:
Optional
[
int
]
=
0
)
->
bool
:
if
use_valid_token_len
:
if
use_valid_token_len
:
if
request
.
request_id
not
in
requsets_valid_token_len
:
if
request
.
request_id
not
in
requsets_valid_token_len
:
requsets_valid_token_len
[
request
.
request_id
]
=
0
requsets_valid_token_len
[
request
.
request_id
]
=
0
return
False
return
False
valid_output_len
=
requsets_valid_token_len
[
request
.
request_id
]
valid_output_len
=
requsets_valid_token_len
[
request
.
request_id
]
-
last_token_offset
else
:
else
:
valid_output_len
=
request
.
num_output_tokens
valid_output_len
=
request
.
num_output_tokens
valid_num_tokens
=
request
.
num_prompt_tokens
+
valid_output_len
valid_num_tokens
=
request
.
num_prompt_tokens
+
valid_output_len
...
@@ -100,7 +101,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -100,7 +101,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state.
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
# This must be called before we make the EngineCoreOutput.
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
True
)
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
use_valid_token_len
=
True
,
last_token_offset
=
len
(
new_token_ids
)
-
num_new
)
if
stopped
:
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
...
...
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
96197e48
...
@@ -38,29 +38,22 @@ def fused_last_valid_scatter_kernel(
...
@@ -38,29 +38,22 @@ def fused_last_valid_scatter_kernel(
BLOCK_T
:
tl
.
constexpr
,
BLOCK_T
:
tl
.
constexpr
,
):
):
pid
=
tl
.
program_id
(
0
)
pid
=
tl
.
program_id
(
0
)
# indices
# indices
req_idx
=
tl
.
load
(
update_req_ptr
+
pid
)
req_idx
=
tl
.
load
(
update_req_ptr
+
pid
)
input_pos
=
tl
.
load
(
input_pos_ptr
+
pid
)
input_pos
=
tl
.
load
(
input_pos_ptr
+
pid
)
# load row
# load row
offs
=
tl
.
arange
(
0
,
BLOCK_T
)
offs
=
tl
.
arange
(
0
,
BLOCK_T
)
mask
=
offs
<
T
mask
=
offs
<
T
row_ptr
=
last_ids_ptr
+
req_idx
*
stride0
+
offs
*
stride1
row_ptr
=
last_ids_ptr
+
req_idx
*
stride0
+
offs
*
stride1
vals
=
tl
.
load
(
row_ptr
,
mask
=
mask
,
other
=-
1
)
vals
=
tl
.
load
(
row_ptr
,
mask
=
mask
,
other
=-
1
)
# ✅ 正确做法:index reduction
idx
=
tl
.
where
(
vals
!=
-
1
,
offs
,
-
1
)
idx
=
tl
.
where
(
vals
!=
-
1
,
offs
,
-
1
)
last_idx
=
tl
.
max
(
idx
,
axis
=
0
)
last_idx
=
tl
.
max
(
idx
,
axis
=
0
)
# load last token
# load last token
last_val
=
tl
.
load
(
last_val
=
tl
.
load
(
last_ids_ptr
+
req_idx
*
stride0
+
last_idx
*
stride1
,
last_ids_ptr
+
req_idx
*
stride0
+
last_idx
*
stride1
,
mask
=
last_idx
>=
0
,
mask
=
last_idx
>=
0
,
other
=
0
,
other
=
0
,
)
)
# scatter
# scatter
tl
.
store
(
input_ids_ptr
+
input_pos
,
last_val
)
tl
.
store
(
input_ids_ptr
+
input_pos
,
last_val
)
...
@@ -138,23 +131,20 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -138,23 +131,20 @@ class V1ZeroModelRunner(GPUModelRunner):
)
)
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
if
start_idx
==
-
1
:
num_accepted_tokens
=
len
(
self
.
fix_sampled_token_ids
[
req_idx
])
self
.
fix_sampled_token_ids
[
req_idx
].
clear
()
req_id
=
self
.
fix_req_ids
[
req_idx
]
else
:
if
req_id
in
self
.
input_batch
.
req_ids
:
num_accepted_tokens
=
len
(
self
.
fix_sampled_token_ids
[
req_idx
])
new_req_idx
=
self
.
input_batch
.
req_ids
.
index
(
req_id
)
req_id
=
self
.
fix_req_ids
[
req_idx
]
new_end_idx
=
start_idx
+
num_accepted_tokens
if
req_id
in
self
.
input_batch
.
req_ids
:
# # 更新token统计数据
new_req_idx
=
self
.
input_batch
.
req_ids
.
index
(
req_id
)
self
.
input_batch
.
num_tokens_no_spec
[
new_req_idx
]
=
new_end_idx
new_end_idx
=
start_idx
+
num_accepted_tokens
self
.
input_batch
.
num_tokens
[
new_req_idx
]
=
new_end_idx
# # 更新token统计数据
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
new_end_idx
]
=
self
.
fix_sampled_token_ids
[
self
.
input_batch
.
num_tokens_no_spec
[
new_req_idx
]
=
new_end_idx
req_idx
]
self
.
input_batch
.
num_tokens
[
new_req_idx
]
=
new_end_idx
self
.
input_batch
.
num_computed_tokens_cpu
[
new_req_idx
]
-=
(
end_idx
-
new_end_idx
)
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
new_end_idx
]
=
self
.
fix_sampled_token_ids
[
if
req_id
in
self
.
requests
:
req_idx
]
req_state
=
self
.
requests
[
req_id
]
self
.
input_batch
.
num_computed_tokens_cpu
[
new_req_idx
]
-=
(
end_idx
-
new_end_idx
)
req_state
.
output_token_ids
.
extend
(
self
.
fix_sampled_token_ids
[
req_idx
])
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
req_state
.
output_token_ids
.
extend
(
self
.
fix_sampled_token_ids
[
req_idx
])
# Get positions.
# Get positions.
positions_np
=
self
.
positions_np
[:
total_num_scheduled_tokens
]
positions_np
=
self
.
positions_np
[:
total_num_scheduled_tokens
]
...
@@ -779,6 +769,31 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -779,6 +769,31 @@ class V1ZeroModelRunner(GPUModelRunner):
for
i
in
discard_sampled_tokens_req_indices
:
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
valid_sampled_token_ids
[
i
].
clear
()
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
fix_draft_req_ids
=
None
else
:
if
self
.
last_draft_host_tokens
is
not
None
:
self
.
last_draft_event
.
synchronize
()
fix_draft_token_ids
=
self
.
last_draft_host_tokens
.
tolist
()
mask
=
(
sampled_token_ids
==
-
1
)
mask_int
=
mask
.
int
()
first_neg_one_indices
=
torch
.
argmax
(
mask_int
,
dim
=
1
)
num_accepted_tokens_tensor
=
torch
.
where
(
torch
.
any
(
mask
,
dim
=
1
),
first_neg_one_indices
,
sampled_token_ids
.
size
(
1
))
-
1
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
num_accepted_tokens_tensor
,
sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
# Cache the sampled tokens in the model runner, so that the scheduler
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
...
@@ -789,12 +804,9 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -789,12 +804,9 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
last_sampled_token_lens
=
[]
self
.
last_sampled_token_lens
=
[]
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
self
.
last_sampled_req_ids
.
append
(
req_id
)
cache_output_len
=
-
1
if
not
sampled_ids
:
if
not
sampled_ids
:
self
.
last_sampled_token_lens
.
append
(
-
1
)
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
-
1
,
-
1
])
continue
continue
self
.
last_sampled_req_ids
.
append
(
req_id
)
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
end_idx
=
start_idx
+
len
(
sampled_ids
)
end_idx
=
start_idx
+
len
(
sampled_ids
)
...
@@ -809,32 +821,6 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -809,32 +821,6 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
fix_draft_req_ids
=
None
else
:
sampled_token_ids_cpu
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
if
self
.
last_draft_host_tokens
is
not
None
:
self
.
last_draft_event
.
synchronize
()
fix_draft_token_ids
=
self
.
last_draft_host_tokens
.
tolist
()
mask
=
(
sampled_token_ids
==
-
1
)
mask_int
=
mask
.
int
()
first_neg_one_indices
=
torch
.
argmax
(
mask_int
,
dim
=
1
)
num_accepted_tokens_tensor
=
torch
.
where
(
torch
.
any
(
mask
,
dim
=
1
),
first_neg_one_indices
,
sampled_token_ids
.
size
(
1
))
-
1
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
num_accepted_tokens_tensor
,
sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
# Clear KVConnector state after all KVs are generated.
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
get_kv_transfer_group
().
clear_connector_metadata
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment