Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc945a5a
Commit
bc945a5a
authored
Dec 24, 2025
by
jujl1
Browse files
fix: 解决同时处理prefill和decode时的prefill请求token计数错误
parent
96197e48
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
114 additions
and
93 deletions
+114
-93
vllm/zero_overhead/utils.py
vllm/zero_overhead/utils.py
+59
-0
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+6
-11
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+49
-82
No files found.
vllm/zero_overhead/utils.py
View file @
bc945a5a
...
@@ -4,6 +4,8 @@ from enum import Enum
...
@@ -4,6 +4,8 @@ from enum import Enum
import
os
import
os
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
triton
import
triton.language
as
tl
zero_no_thread
=
os
.
environ
.
get
(
'VLLM_ZERO_NO_THREAD'
)
==
'1'
zero_no_thread
=
os
.
environ
.
get
(
'VLLM_ZERO_NO_THREAD'
)
==
'1'
...
@@ -69,3 +71,60 @@ def zero_overhead_stream(target_device):
...
@@ -69,3 +71,60 @@ def zero_overhead_stream(target_device):
if
target_device
not
in
alloc_stream
.
keys
():
if
target_device
not
in
alloc_stream
.
keys
():
alloc_stream
[
target_device
]
=
torch
.
cuda
.
Stream
(
device
=
target_device
)
alloc_stream
[
target_device
]
=
torch
.
cuda
.
Stream
(
device
=
target_device
)
return
alloc_stream
[
target_device
]
return
alloc_stream
[
target_device
]
@triton.jit
def fused_last_valid_scatter_kernel(
    last_ids_ptr,    # [B, T]
    input_ids_ptr,   # [N]
    update_req_ptr,  # [U]
    input_pos_ptr,   # [U]
    stride0,
    stride1,
    T,
    BLOCK_T: tl.constexpr,
):
    """Scatter the last non-sentinel entry of selected rows into ``input_ids``.

    One program instance handles one (row, destination) pair.  Program ``p``
    reads a row index from ``update_req_ptr[p]`` and a destination offset from
    ``input_pos_ptr[p]``, scans the first ``T`` columns of that row of
    ``last_ids_ptr`` for the highest column whose value is not ``-1``, and
    stores that value at ``input_ids_ptr[destination]``.

    If every scanned column holds ``-1`` the value ``0`` is stored instead.

    Args:
        last_ids_ptr: 2-D buffer ([B, T]) addressed with ``stride0`` (row)
            and ``stride1`` (column).
        input_ids_ptr: 1-D destination buffer ([N]).
        update_req_ptr: 1-D buffer ([U]) of row indices into ``last_ids_ptr``.
        input_pos_ptr: 1-D buffer ([U]) of offsets into ``input_ids_ptr``.
        stride0: Element stride between consecutive rows of ``last_ids_ptr``.
        stride1: Element stride between consecutive columns.
        T: Number of columns to scan in each row.
        BLOCK_T: Compile-time width of the column block; columns at or beyond
            ``T`` are masked off, so it must be at least ``T`` for the whole
            row to be seen.
    """
    program = tl.program_id(0)

    # Per-program routing: which row to read and where to write.
    src_row = tl.load(update_req_ptr + program)
    dst_slot = tl.load(input_pos_ptr + program)

    # Fetch the row; masked-off lanes read as the sentinel -1.
    cols = tl.arange(0, BLOCK_T)
    in_row = cols < T
    row_base = last_ids_ptr + src_row * stride0
    row_vals = tl.load(row_base + cols * stride1, mask=in_row, other=-1)

    # Highest column holding a non-sentinel value, or -1 when there is none.
    valid_cols = tl.where(row_vals != -1, cols, -1)
    last_col = tl.max(valid_cols, axis=0)

    # Re-read that single element; fall back to 0 when the row had no valid entry.
    token = tl.load(
        row_base + last_col * stride1,
        mask=last_col >= 0,
        other=0,
    )

    tl.store(input_ids_ptr + dst_slot, token)
def fused_update_input_ids_impl(
    last_sampled_token_ids,
    input_ids,
    update_req_indices,
    input_ids_indices,
):
    """Launch ``fused_last_valid_scatter_kernel`` over the given index pairs.

    For every position ``i`` the kernel takes row ``update_req_indices[i]`` of
    ``last_sampled_token_ids``, picks its last entry that is not ``-1`` and
    writes it to ``input_ids[input_ids_indices[i]]`` in place.

    Args:
        last_sampled_token_ids: 2-D tensor; its shape is unpacked as (B, T).
        input_ids: Destination tensor, modified in place by the kernel.
        update_req_indices: Row indices into ``last_sampled_token_ids``; its
            element count determines the launch grid.
        input_ids_indices: Destination offsets into ``input_ids``.
            NOTE(review): presumably the same length as
            ``update_req_indices`` - confirm against the caller.

    Raises:
        AssertionError: If the second dimension of ``last_sampled_token_ids``
            exceeds the fixed block width of 1024.
    """
    _num_rows, num_cols = last_sampled_token_ids.shape
    num_updates = update_req_indices.numel()

    # The kernel covers a row with one block-wide masked load, so the row
    # must fit inside the block.
    block_width = 1024
    assert num_cols <= block_width

    # One program per (row, destination) pair.
    launch_grid = (num_updates,)
    launcher = fused_last_valid_scatter_kernel[launch_grid]
    launcher(
        last_sampled_token_ids,
        input_ids,
        update_req_indices,
        input_ids_indices,
        last_sampled_token_ids.stride(0),
        last_sampled_token_ids.stride(1),
        num_cols,
        BLOCK_T=block_width,
    )
\ No newline at end of file
vllm/zero_overhead/v1/core.py
View file @
bc945a5a
...
@@ -82,16 +82,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -82,16 +82,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
requsets_valid_token_len
[
req_id
]
+=
1
requsets_valid_token_len
[
req_id
]
+=
1
generated_token_ids
=
[
generated_token_ids
]
generated_token_ids
=
[
generated_token_ids
]
else
:
else
:
valid_output_end
=
valid_output_len
+
len
(
generated_token_ids
)
-
request
.
num_output_tokens
if
valid_output_end
==
0
:
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
else
:
request
.
_output_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
request
.
_output_token_ids
[:]
=
request
.
_output_token_ids
[:
requsets_valid_token_len
[
req_id
]]
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[:]
=
request
.
_all_token_ids
[:
request
.
num_prompt_tokens
+
requsets_valid_token_len
[
req_id
]]
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
stopped
=
False
stopped
=
False
new_logprobs
=
None
new_logprobs
=
None
...
@@ -194,7 +188,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -194,7 +188,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
generated_token_ids
=
sampled_token_ids
[
generated_token_ids
=
sampled_token_ids
[
req_index
]
if
sampled_token_ids
else
[]
req_index
]
if
sampled_token_ids
else
[]
if
request
.
num_computed_tokens
==
request
.
num_prompt_tokens
:
generated_token_ids
=
generated_token_ids
[:
1
]
scheduled_spec_token_ids
=
(
scheduled_spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
))
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
))
...
@@ -207,7 +202,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -207,7 +202,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_new
=
len
(
generated_token_ids
)
num_new
=
len
(
generated_token_ids
)
if
(
model_runner_output
.
fix_req_ids
and
req_id
in
model_runner_output
.
fix_req_ids
if
(
model_runner_output
.
fix_req_ids
and
req_id
in
model_runner_output
.
fix_req_ids
and
request
.
num_computed_tokens
>
=
request
.
num_prompt_tokens
+
num_new
):
and
request
.
num_computed_tokens
>
request
.
num_prompt_tokens
+
num_new
):
req_idx
=
model_runner_output
.
fix_req_ids
.
index
(
req_id
)
req_idx
=
model_runner_output
.
fix_req_ids
.
index
(
req_id
)
num_new
=
len
(
model_runner_output
.
fix_sampled_token_ids
[
req_idx
])
num_new
=
len
(
model_runner_output
.
fix_sampled_token_ids
[
req_idx
])
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
num_new
)
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
num_new
)
...
...
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
bc945a5a
...
@@ -22,40 +22,7 @@ from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
...
@@ -22,40 +22,7 @@ from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from
vllm.profiler.prof
import
profile
from
vllm.profiler.prof
import
profile
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.zero_overhead.utils
import
fused_update_input_ids_impl
import
triton
import
triton.language
as
tl
@triton.jit
def fused_last_valid_scatter_kernel(
    last_ids_ptr,    # [B, T]
    input_ids_ptr,   # [N]
    update_req_ptr,  # [U]
    input_pos_ptr,   # [U]
    stride0,
    stride1,
    T,
    BLOCK_T: tl.constexpr,
):
    """Scatter the last non-sentinel entry of selected rows into ``input_ids``.

    Each program instance reads a row index from ``update_req_ptr`` and a
    destination offset from ``input_pos_ptr`` (both indexed by the program
    id), finds the highest column among the first ``T`` of that row of
    ``last_ids_ptr`` whose value is not ``-1``, and stores that value at the
    destination offset of ``input_ids_ptr``.  When the whole row is ``-1``
    the value ``0`` is stored.

    Args:
        last_ids_ptr: 2-D buffer ([B, T]) addressed via ``stride0``/``stride1``.
        input_ids_ptr: 1-D destination buffer ([N]).
        update_req_ptr: 1-D buffer ([U]) of row indices.
        input_pos_ptr: 1-D buffer ([U]) of destination offsets.
        stride0: Element stride between rows of ``last_ids_ptr``.
        stride1: Element stride between columns of ``last_ids_ptr``.
        T: Number of columns to scan per row.
        BLOCK_T: Compile-time block width; lanes at or beyond ``T`` are masked.
    """
    program = tl.program_id(0)

    # Per-program routing: source row and destination slot.
    src_row = tl.load(update_req_ptr + program)
    dst_slot = tl.load(input_pos_ptr + program)

    # Masked row fetch; out-of-range lanes read as the sentinel -1.
    cols = tl.arange(0, BLOCK_T)
    in_row = cols < T
    row_base = last_ids_ptr + src_row * stride0
    row_vals = tl.load(row_base + cols * stride1, mask=in_row, other=-1)

    # Highest column with a non-sentinel value, or -1 if there is none.
    valid_cols = tl.where(row_vals != -1, cols, -1)
    last_col = tl.max(valid_cols, axis=0)

    # Re-read that element; fall back to 0 when no valid entry exists.
    token = tl.load(
        row_base + last_col * stride1,
        mask=last_col >= 0,
        other=0,
    )

    tl.store(input_ids_ptr + dst_slot, token)
class
V1ZeroModelRunner
(
GPUModelRunner
):
class
V1ZeroModelRunner
(
GPUModelRunner
):
def
__init__
(
self
,
vllm_config
,
device
):
def
__init__
(
self
,
vllm_config
,
device
):
...
@@ -116,21 +83,21 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -116,21 +83,21 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens
,
arange
=
self
.
_get_cumsum_and_arange
(
cu_num_tokens
,
arange
=
self
.
_get_cumsum_and_arange
(
num_scheduled_tokens
)
num_scheduled_tokens
)
self
.
fix_req_ids
=
self
.
last_sampled_req_ids
if
self
.
last_sampler_host_tokens
!=
None
:
self
.
last_sampler_event
.
synchronize
()
# 等上一轮主模型结束
if
self
.
speculative_config
:
# 处理上一轮mtp
num_gen_tokens
=
self
.
last_sampler_host_tokens
.
shape
[
-
1
]
if
num_gen_tokens
==
1
:
self
.
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
else
:
# Includes spec decode tokens.
self
.
fix_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
self
.
last_sampler_host_tokens
,
self
.
input_batch
.
vocab_size
,
)
if
self
.
speculative_config
and
self
.
last_sampler_host_tokens
!=
None
:
self
.
fix_req_ids
=
self
.
last_sampled_req_ids
self
.
last_sampler_event
.
synchronize
()
# 等上一轮主模型结束
num_gen_tokens
=
self
.
last_sampler_host_tokens
.
shape
[
-
1
]
if
num_gen_tokens
==
1
:
self
.
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
else
:
self
.
fix_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
self
.
last_sampler_host_tokens
,
self
.
input_batch
.
vocab_size
,
)
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
if
start_idx
==
-
1
:
continue
num_accepted_tokens
=
len
(
self
.
fix_sampled_token_ids
[
req_idx
])
num_accepted_tokens
=
len
(
self
.
fix_sampled_token_ids
[
req_idx
])
req_id
=
self
.
fix_req_ids
[
req_idx
]
req_id
=
self
.
fix_req_ids
[
req_idx
]
if
req_id
in
self
.
input_batch
.
req_ids
:
if
req_id
in
self
.
input_batch
.
req_ids
:
...
@@ -332,31 +299,6 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -332,31 +299,6 @@ class V1ZeroModelRunner(GPUModelRunner):
True
)
True
)
last_draft_token_ids
=
self
.
last_draft_token_ids
.
flatten
().
to
(
torch
.
int
)
last_draft_token_ids
=
self
.
last_draft_token_ids
.
flatten
().
to
(
torch
.
int
)
input_ids
[
input_ids_indices_tensor
]
=
last_draft_token_ids
[
update_req_indices_tensor
]
input_ids
[
input_ids_indices_tensor
]
=
last_draft_token_ids
[
update_req_indices_tensor
]
def
fused_update_input_ids
(
last_sampled_token_ids
,
input_ids
,
update_req_indices
,
input_ids_indices
,
):
B
,
T
=
last_sampled_token_ids
.
shape
U
=
update_req_indices
.
numel
()
BLOCK_T
=
1024
assert
T
<=
BLOCK_T
grid
=
(
U
,)
fused_last_valid_scatter_kernel
[
grid
](
last_sampled_token_ids
,
input_ids
,
update_req_indices
,
input_ids_indices
,
last_sampled_token_ids
.
stride
(
0
),
last_sampled_token_ids
.
stride
(
1
),
T
,
BLOCK_T
=
BLOCK_T
,
)
update_req_indices
=
[]
update_req_indices
=
[]
input_ids_indices
=
[]
input_ids_indices
=
[]
token_idx
=
0
token_idx
=
0
...
@@ -374,13 +316,12 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -374,13 +316,12 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
self
.
device
,
self
.
device
,
True
)
True
)
fused_update_input_ids
(
if
self
.
speculative_config
:
self
.
last_sampled_token_ids
,
fused_update_input_ids_impl
(
self
.
last_sampled_token_ids
,
input_ids
,
input_ids
,
update_req_indices_tensor
,
input_ids_indices_tensor
)
update_req_indices_tensor
,
else
:
input_ids_indices_tensor
)
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
flatten
()
input_ids
[
input_ids_indices_tensor
]
=
last_sampled_token_ids
[
update_req_indices_tensor
]
def
propose_draft_token_ids
(
def
propose_draft_token_ids
(
self
,
self
,
...
@@ -757,7 +698,26 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -757,7 +698,26 @@ class V1ZeroModelRunner(GPUModelRunner):
is_output_valid
=
False
is_output_valid
=
False
# Get the valid generated tokens.
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
sampled_token_ids
=
sampler_output
.
sampled_token_ids
if
not
self
.
speculative_config
:
self
.
fix_req_ids
=
self
.
last_sampled_req_ids
if
self
.
last_sampler_host_tokens
is
not
None
:
self
.
last_sampler_event
.
synchronize
()
self
.
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
if
start_idx
==
-
1
:
continue
req_id
=
self
.
fix_req_ids
[
req_idx
]
if
req_id
in
self
.
input_batch
.
req_ids
:
new_req_idx
=
self
.
input_batch
.
req_ids
.
index
(
req_id
)
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
end_idx
]
=
self
.
fix_sampled_token_ids
[
req_idx
]
for
req_idx
,
req_id
in
enumerate
(
self
.
fix_req_ids
):
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
token_idx
=
self
.
last_sampled_token_lens
[
req_idx
]
if
token_idx
==
-
1
:
continue
fix_len
=
len
(
self
.
fix_sampled_token_ids
[
req_idx
])
req_state
.
output_token_ids
[
token_idx
:
token_idx
+
fix_len
]
=
self
.
fix_sampled_token_ids
[
req_idx
]
self
.
last_sampler_host_tokens
=
None
self
.
last_sampler_host_tokens
=
None
self
.
last_sampled_token_ids
=
None
self
.
last_sampled_token_ids
=
None
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
...
@@ -804,10 +764,12 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -804,10 +764,12 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
last_sampled_token_lens
=
[]
self
.
last_sampled_token_lens
=
[]
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
cache_output_len
=
-
1
self
.
last_sampled_req_ids
.
append
(
req_id
)
if
not
sampled_ids
:
if
not
sampled_ids
:
self
.
last_sampled_token_lens
.
append
(
-
1
)
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
-
1
,
-
1
])
continue
continue
self
.
last_sampled_req_ids
.
append
(
req_id
)
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
end_idx
=
start_idx
+
len
(
sampled_ids
)
end_idx
=
start_idx
+
len
(
sampled_ids
)
assert
end_idx
<=
self
.
max_model_len
,
(
assert
end_idx
<=
self
.
max_model_len
,
(
...
@@ -820,6 +782,11 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -820,6 +782,11 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
start_idx
,
end_idx
])
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
start_idx
,
end_idx
])
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
if
not
self
.
speculative_config
and
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
cache_output_len
=
len
(
req_state
.
output_token_ids
)
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
self
.
last_sampled_token_lens
.
append
(
cache_output_len
)
# Clear KVConnector state after all KVs are generated.
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment