Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aaef2077
Commit
aaef2077
authored
Aug 26, 2025
by
zhuwenwen
Browse files
v1 engine eager tbo support mla attention
parent
ffbb211b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
17 deletions
+45
-17
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+2
-1
vllm/two_batch_overlap/v1/model_input_split_v1.py
vllm/two_batch_overlap/v1/model_input_split_v1.py
+39
-14
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
+4
-2
No files found.
vllm/two_batch_overlap/two_batch_overlap.py
View file @
aaef2077
...
...
@@ -58,7 +58,8 @@ class TwoBatchOverlap():
self
.
left_thread
.
start
()
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
.
start
()
logger
.
info
(
'tbo:two batch overlap start'
)
if
get_tp_group
().
rank
==
0
:
logger
.
info
(
'tbo:two batch overlap start'
)
def
finish_thread
(
self
):
self
.
left_thread
.
join
()
...
...
vllm/two_batch_overlap/v1/model_input_split_v1.py
View file @
aaef2077
...
...
@@ -8,6 +8,7 @@ from vllm.forward_context import set_forward_context
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_model_executable_v1
from
vllm.utils
import
async_tensor_h2d
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadataBuilder
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
...
...
@@ -223,28 +224,47 @@ def prepare_tbo_atten_metadata(
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
metadata_builder
=
runner
.
attn_metadata_builders
[
kv_cache_group_id
]
if
runner
.
cascade_attn_enabled
:
common_prefix_len
=
runner
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
,
metadata_builder
,
)
if
req_offset
>
0
:
origin_block_table
=
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
block_table
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
origin_slot_mapping
=
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
slot_mapping
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
slot_mapping
=
\
origin_block_table
=
metadata_builder
.
block_table
.
block_table
metadata_builder
.
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
origin_slot_mapping
=
metadata_builder
.
block_table
.
slot_mapping
metadata_builder
.
block_table
.
slot_mapping
=
\
origin_slot_mapping
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
_num_decodes_record
=
metadata_builder
.
_num_decodes
_num_prefills_record
=
metadata_builder
.
_num_prefills
_num_decode_tokens_record
=
metadata_builder
.
_num_decode_tokens
_num_prefill_tokens_record
=
metadata_builder
.
_num_prefill_tokens
metadata_builder
.
_num_decodes
=
0
metadata_builder
.
_num_prefills
=
num_reqs
metadata_builder
.
_num_decode_tokens
=
0
metadata_builder
.
_num_prefill_tokens
=
total_num_scheduled_tokens
attn_metadata_i
=
(
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
build
(
metadata_builder
.
build
(
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
))
# maybe FlashAttentionMetadata
if
req_offset
>
0
:
runner
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
block_table
=
origin_block_table
runner
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
slot_mapping
=
origin_slot_mapping
metadata_builder
.
block_table
.
block_table
=
origin_block_table
metadata_builder
.
block_table
.
slot_mapping
=
origin_slot_mapping
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
metadata_builder
.
_num_decodes
=
_num_decodes_record
metadata_builder
.
_num_prefills
=
_num_prefills_record
metadata_builder
.
_num_decode_tokens
=
_num_decode_tokens_record
metadata_builder
.
_num_prefill_tokens
=
_num_prefill_tokens_record
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
...
...
@@ -287,11 +307,15 @@ def tbo_split_and_execute_model(
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
use_tbo
=
False
if
len
(
scheduler_output
.
num_scheduled_tokens
)
>
1
:
split_scheduler_output
(
runner
,
scheduler_output
)
if
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
and
\
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
:
use_tbo
=
True
if
isinstance
(
runner
.
attn_metadata_builders
[
0
],
MLACommonMetadataBuilder
)
and
\
runner
.
attn_metadata_builders
[
0
].
_num_decodes
>
0
:
#is mla decode
use_tbo
=
False
else
:
if
len
(
scheduler_output
.
num_scheduled_tokens
)
>
1
:
split_scheduler_output
(
runner
,
scheduler_output
)
if
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
and
\
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
:
use_tbo
=
True
if
use_tbo
:
num_input_tokens_left
=
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
...
...
@@ -318,7 +342,8 @@ def tbo_split_and_execute_model(
with
set_forward_context
(
attn_metadata
,
runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
):
num_tokens_across_dp
=
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
runner
.
model
(
...
...
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
View file @
aaef2077
...
...
@@ -50,7 +50,8 @@ class TwoBatchOverlap():
self
.
left_thread
.
start
()
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
.
start
()
logger
.
info
(
'tbo:two batch overlap start'
)
if
get_tp_group
().
rank
==
0
:
logger
.
info
(
'tbo:two batch overlap start'
)
def
finish_thread
(
self
):
self
.
left_thread
.
join
()
...
...
@@ -90,7 +91,8 @@ class TwoBatchOverlap():
with
set_forward_context
(
attn_metadata
,
self
.
model_runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
self
.
num_tokens_across_dp
):
num_tokens_across_dp
=
self
.
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
model_output
=
self
.
model_runner
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment