Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f6f8db81
Commit
f6f8db81
authored
Sep 03, 2025
by
lizhigong
Browse files
fix bugs in zero overhead and tbo
parent
14201006
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
64 deletions
+31
-64
vllm/two_batch_overlap/v1/model_input_split_v1.py
vllm/two_batch_overlap/v1/model_input_split_v1.py
+27
-61
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+3
-2
vllm/zero_overhead/v1/outputs.py
vllm/zero_overhead/v1/outputs.py
+1
-1
No files found.
vllm/two_batch_overlap/v1/model_input_split_v1.py
View file @
f6f8db81
...
@@ -25,6 +25,7 @@ class TBOModelInputSplit():
...
@@ -25,6 +25,7 @@ class TBOModelInputSplit():
self
.
req_num_right
=
0
self
.
req_num_right
=
0
self
.
scheduler_output_left
=
None
self
.
scheduler_output_left
=
None
self
.
scheduler_output_right
=
None
self
.
scheduler_output_right
=
None
self
.
query_start_loc_right
=
None
input_split
=
TBOModelInputSplit
()
input_split
=
TBOModelInputSplit
()
...
@@ -136,78 +137,39 @@ def prepare_tbo_atten_metadata(
...
@@ -136,78 +137,39 @@ def prepare_tbo_atten_metadata(
assert
num_reqs
>
0
assert
num_reqs
>
0
seq_len_offset
=
req_offset
seq_len_offset
=
req_offset
if
req_offset
==
0
:
#left
query_start_offset
=
0
else
:
query_start_offset
=
req_offset
+
1
# Get the number of scheduled tokens for each request.
# Get the number of scheduled tokens for each request.
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
max
(
tokens
)
max_num_scheduled_tokens
=
max
(
tokens
)
# Get request indices.
if
req_offset
>
0
:
#right
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
if
input_split
.
query_start_loc_right
==
None
:
req_indices
=
np
.
repeat
(
runner
.
arange_np
[:
num_reqs
],
# TODO: create when system init
num_scheduled_tokens
)
+
req_offset
input_split
.
query_start_loc_right
=
torch
.
zeros
(
runner
.
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
runner
.
device
)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
cu_num_tokens
,
arange
=
runner
.
_get_cumsum_and_arange
(
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
num_scheduled_tokens
)
cu_num_tokens
,
arange
=
runner
.
_get_cumsum_and_arange
(
num_scheduled_tokens
)
# Get positions.
# Prepare the attention metadata.
positions_np
=
runner
.
positions_np
[:
total_num_scheduled_tokens
]
runner
.
query_start_loc_np
[
0
]
=
0
np
.
add
(
runner
.
input_batch
.
num_computed_tokens_cpu
[
req_indices
],
runner
.
query_start_loc_np
[
1
:
num_reqs
+
1
]
=
cu_num_tokens
arange
,
out
=
positions_np
)
# Calculate the slot mapping for each KV cache group.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
input_split
.
query_start_loc_right
[
0
:
num_reqs
+
1
].
copy_
(
runner
.
kv_cache_config
.
kv_cache_groups
):
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
],
non_blocking
=
True
)
block_size
=
kv_cache_group_spec
.
kv_cache_spec
.
block_size
# Note: pad query_start_loc to be non-decreasing, as kernels
block_table
:
BlockTable
=
runner
.
input_batch
.
block_table
[
# like FlashAttention requires that
kv_cache_group_id
]
input_split
.
query_start_loc_right
[
num_reqs
+
1
:].
fill_
(
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices
=
(
req_indices
*
block_table
.
max_num_blocks_per_req
+
positions_np
//
block_size
)
block_table_cpu
=
block_table
.
get_cpu_tensor
()
block_numbers
=
block_table_cpu
.
flatten
(
)[
block_table_indices
].
numpy
()
block_offsets
=
positions_np
%
block_size
np
.
add
(
block_numbers
*
block_size
,
block_offsets
,
out
=
block_table
.
slot_mapping_np
[:
total_num_scheduled_tokens
])
# Prepare the attention metadata.
runner
.
query_start_loc_np
[
0
]
=
0
runner
.
query_start_loc_np
[
1
:
num_reqs
+
1
]
=
cu_num_tokens
runner
.
seq_lens_np
[:
num_reqs
]
=
(
runner
.
input_batch
.
num_computed_tokens_cpu
[
req_offset
:
req_offset
+
num_reqs
]
+
num_scheduled_tokens
)
runner
.
query_start_loc
[
query_start_offset
:
query_start_offset
+
num_reqs
+
1
].
copy_
(
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
],
non_blocking
=
True
)
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
if
req_offset
>
0
:
#right
runner
.
query_start_loc
[
query_start_offset
+
num_reqs
+
1
:].
fill_
(
runner
.
query_start_loc_cpu
[
num_reqs
].
item
())
runner
.
query_start_loc_cpu
[
num_reqs
].
item
())
runner
.
seq_lens
[
seq_len_offset
:
seq_len_offset
+
num_reqs
].
copy_
(
runner
.
seq_lens_cpu
[:
num_reqs
],
query_start_loc
=
input_split
.
query_start_loc_right
[:
num_reqs
+
1
]
non_blocking
=
True
)
# Fill unused with -1. Needed for reshape_and_cache
if
req_offset
>
0
:
#right
runner
.
seq_lens
[
seq_len_offset
+
num_reqs
:].
fill_
(
0
)
query_start_loc
=
runner
.
query_start_loc
[
query_start_offset
:
query_start_offset
+
num_reqs
+
1
]
else
:
query_start_loc
=
runner
.
query_start_loc
[:
num_reqs
+
1
]
seq_lens
=
runner
.
seq_lens
[
seq_len_offset
:
seq_len_offset
+
num_reqs
]
seq_lens
=
runner
.
seq_lens
[
seq_len_offset
:
seq_len_offset
+
num_reqs
]
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
...
@@ -240,6 +202,9 @@ def prepare_tbo_atten_metadata(
...
@@ -240,6 +202,9 @@ def prepare_tbo_atten_metadata(
origin_slot_mapping
=
metadata_builder
.
block_table
.
slot_mapping
origin_slot_mapping
=
metadata_builder
.
block_table
.
slot_mapping
metadata_builder
.
block_table
.
slot_mapping
=
\
metadata_builder
.
block_table
.
slot_mapping
=
\
origin_slot_mapping
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
origin_slot_mapping
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
origin_slot_map_cpu
=
metadata_builder
.
block_table
.
slot_mapping_cpu
metadata_builder
.
block_table
.
slot_mapping_cpu
=
\
origin_slot_map_cpu
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
_num_decodes_record
=
metadata_builder
.
_num_decodes
_num_decodes_record
=
metadata_builder
.
_num_decodes
_num_prefills_record
=
metadata_builder
.
_num_prefills
_num_prefills_record
=
metadata_builder
.
_num_prefills
...
@@ -257,6 +222,7 @@ def prepare_tbo_atten_metadata(
...
@@ -257,6 +222,7 @@ def prepare_tbo_atten_metadata(
if
req_offset
>
0
:
if
req_offset
>
0
:
metadata_builder
.
block_table
.
block_table
=
origin_block_table
metadata_builder
.
block_table
.
block_table
=
origin_block_table
metadata_builder
.
block_table
.
slot_mapping
=
origin_slot_mapping
metadata_builder
.
block_table
.
slot_mapping
=
origin_slot_mapping
metadata_builder
.
block_table
.
slot_mapping_cpu
=
origin_slot_map_cpu
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
metadata_builder
.
_num_decodes
=
_num_decodes_record
metadata_builder
.
_num_decodes
=
_num_decodes_record
...
...
vllm/zero_overhead/v1/core.py
View file @
f6f8db81
...
@@ -80,6 +80,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -80,6 +80,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
request
.
_output_token_ids
[
fix_offset
]
=
generated_token_ids
request
.
_output_token_ids
[
fix_offset
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
1
requsets_valid_token_len
[
req_id
]
+=
1
generated_token_ids
=
[
generated_token_ids
]
else
:
else
:
valid_output_end
=
valid_output_len
+
len
(
generated_token_ids
)
-
request
.
num_output_tokens
valid_output_end
=
valid_output_len
+
len
(
generated_token_ids
)
-
request
.
num_output_tokens
if
valid_output_end
==
0
:
if
valid_output_end
==
0
:
...
@@ -107,7 +108,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -107,7 +108,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
pooler_output
=
None
pooler_output
=
None
if
pooler_outputs
:
if
pooler_outputs
:
pooler_output
=
pooler_outputs
[
req_i
nde
x
]
pooler_output
=
pooler_outputs
[
req_i
d
x
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
,
True
)
pooler_output
,
True
)
if
stopped
:
if
stopped
:
...
@@ -118,7 +119,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -118,7 +119,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
# NOTE: once we support N tokens per step (spec decode),
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_i
nde
x
,
req_i
nde
x
+
1
)
new_logprobs
=
logprobs
.
slice
(
req_i
d
x
,
req_i
d
x
+
1
)
if
new_token_ids
and
scheduler
.
structured_output_manager
.
should_advance
(
if
new_token_ids
and
scheduler
.
structured_output_manager
.
should_advance
(
request
):
request
):
...
...
vllm/zero_overhead/v1/outputs.py
View file @
f6f8db81
...
@@ -9,6 +9,6 @@ class ZeroV1ModelRunnerOutput(ModelRunnerOutput):
...
@@ -9,6 +9,6 @@ class ZeroV1ModelRunnerOutput(ModelRunnerOutput):
# [num_reqs]
# [num_reqs]
fix_req_ids
:
list
[
str
]
=
None
fix_req_ids
:
list
[
str
]
=
None
fix_sampled_token_ids
:
list
[
list
[
int
]]
=
None
fix_sampled_token_ids
:
list
[
list
[
int
]]
=
None
fix_draft_req_ids
:
list
[
list
[
int
]
]
=
None
fix_draft_req_ids
:
list
[
str
]
=
None
fix_draft_tokens_ids
:
list
[
list
[
int
]]
=
None
fix_draft_tokens_ids
:
list
[
list
[
int
]]
=
None
is_output_valid
:
bool
=
True
is_output_valid
:
bool
=
True
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment