Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aaef2077
Commit
aaef2077
authored
Aug 26, 2025
by
zhuwenwen
Browse files
v1 engine eager tbo support mla attention
parent
ffbb211b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
17 deletions
+45
-17
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+2
-1
vllm/two_batch_overlap/v1/model_input_split_v1.py
vllm/two_batch_overlap/v1/model_input_split_v1.py
+39
-14
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
+4
-2
No files found.
vllm/two_batch_overlap/two_batch_overlap.py
View file @
aaef2077
...
@@ -58,6 +58,7 @@ class TwoBatchOverlap():
...
@@ -58,6 +58,7 @@ class TwoBatchOverlap():
self
.
left_thread
.
start
()
self
.
left_thread
.
start
()
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
.
start
()
self
.
right_thread
.
start
()
if
get_tp_group
().
rank
==
0
:
logger
.
info
(
'tbo:two batch overlap start'
)
logger
.
info
(
'tbo:two batch overlap start'
)
def
finish_thread
(
self
):
def
finish_thread
(
self
):
...
...
vllm/two_batch_overlap/v1/model_input_split_v1.py
View file @
aaef2077
...
@@ -8,6 +8,7 @@ from vllm.forward_context import set_forward_context
...
@@ -8,6 +8,7 @@ from vllm.forward_context import set_forward_context
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_model_executable_v1
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_model_executable_v1
from
vllm.utils
import
async_tensor_h2d
from
vllm.utils
import
async_tensor_h2d
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadataBuilder
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
...
@@ -223,28 +224,47 @@ def prepare_tbo_atten_metadata(
...
@@ -223,28 +224,47 @@ def prepare_tbo_atten_metadata(
# Prepare for cascade attention if enabled & beneficial.
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
common_prefix_len
=
0
metadata_builder
=
runner
.
attn_metadata_builders
[
kv_cache_group_id
]
if
runner
.
cascade_attn_enabled
:
if
runner
.
cascade_attn_enabled
:
common_prefix_len
=
runner
.
_compute_cascade_attn_prefix_len
(
common_prefix_len
=
runner
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
num_scheduled_tokens
,
scheduler_output
.
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
kv_cache_group_spec
.
kv_cache_spec
,
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
,
metadata_builder
,
)
)
if
req_offset
>
0
:
if
req_offset
>
0
:
origin_block_table
=
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
block_table
origin_block_table
=
metadata_builder
.
block_table
.
block_table
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
metadata_builder
.
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
origin_slot_mapping
=
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
slot_mapping
origin_slot_mapping
=
metadata_builder
.
block_table
.
slot_mapping
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
slot_mapping
=
\
metadata_builder
.
block_table
.
slot_mapping
=
\
origin_slot_mapping
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
origin_slot_mapping
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
_num_decodes_record
=
metadata_builder
.
_num_decodes
_num_prefills_record
=
metadata_builder
.
_num_prefills
_num_decode_tokens_record
=
metadata_builder
.
_num_decode_tokens
_num_prefill_tokens_record
=
metadata_builder
.
_num_prefill_tokens
metadata_builder
.
_num_decodes
=
0
metadata_builder
.
_num_prefills
=
num_reqs
metadata_builder
.
_num_decode_tokens
=
0
metadata_builder
.
_num_prefill_tokens
=
total_num_scheduled_tokens
attn_metadata_i
=
(
attn_metadata_i
=
(
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
build
(
metadata_builder
.
build
(
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
))
# maybe FlashAttentionMetadata
common_attn_metadata
=
common_attn_metadata
))
# maybe FlashAttentionMetadata
if
req_offset
>
0
:
if
req_offset
>
0
:
runner
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
block_table
=
origin_block_table
metadata_builder
.
block_table
.
block_table
=
origin_block_table
runner
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
slot_mapping
=
origin_slot_mapping
metadata_builder
.
block_table
.
slot_mapping
=
origin_slot_mapping
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
metadata_builder
.
_num_decodes
=
_num_decodes_record
metadata_builder
.
_num_prefills
=
_num_prefills_record
metadata_builder
.
_num_decode_tokens
=
_num_decode_tokens_record
metadata_builder
.
_num_prefill_tokens
=
_num_prefill_tokens_record
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
attn_metadata
[
layer_name
]
=
attn_metadata_i
...
@@ -287,6 +307,10 @@ def tbo_split_and_execute_model(
...
@@ -287,6 +307,10 @@ def tbo_split_and_execute_model(
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
use_tbo
=
False
use_tbo
=
False
if
isinstance
(
runner
.
attn_metadata_builders
[
0
],
MLACommonMetadataBuilder
)
and
\
runner
.
attn_metadata_builders
[
0
].
_num_decodes
>
0
:
#is mla decode
use_tbo
=
False
else
:
if
len
(
scheduler_output
.
num_scheduled_tokens
)
>
1
:
if
len
(
scheduler_output
.
num_scheduled_tokens
)
>
1
:
split_scheduler_output
(
runner
,
scheduler_output
)
split_scheduler_output
(
runner
,
scheduler_output
)
if
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
and
\
if
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
and
\
...
@@ -318,7 +342,8 @@ def tbo_split_and_execute_model(
...
@@ -318,7 +342,8 @@ def tbo_split_and_execute_model(
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
attn_metadata
,
runner
.
vllm_config
,
runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
):
num_tokens_across_dp
=
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
runner
.
model
(
model_output
=
runner
.
model
(
...
...
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
View file @
aaef2077
...
@@ -50,6 +50,7 @@ class TwoBatchOverlap():
...
@@ -50,6 +50,7 @@ class TwoBatchOverlap():
self
.
left_thread
.
start
()
self
.
left_thread
.
start
()
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
.
start
()
self
.
right_thread
.
start
()
if
get_tp_group
().
rank
==
0
:
logger
.
info
(
'tbo:two batch overlap start'
)
logger
.
info
(
'tbo:two batch overlap start'
)
def
finish_thread
(
self
):
def
finish_thread
(
self
):
...
@@ -90,7 +91,8 @@ class TwoBatchOverlap():
...
@@ -90,7 +91,8 @@ class TwoBatchOverlap():
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
attn_metadata
,
self
.
model_runner
.
vllm_config
,
self
.
model_runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
self
.
num_tokens_across_dp
):
num_tokens_across_dp
=
self
.
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
model_output
=
self
.
model_runner
.
model
(
model_output
=
self
.
model_runner
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment