Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7124e74d
Commit
7124e74d
authored
May 26, 2025
by
lizhigong
Browse files
fix tbo support deepseek mtp
parent
7f022e4d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
6 deletions
+31
-6
vllm/two_batch_overlap/model_input_split.py
vllm/two_batch_overlap/model_input_split.py
+9
-2
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+22
-4
No files found.
vllm/two_batch_overlap/model_input_split.py
View file @
7124e74d
...
@@ -98,6 +98,13 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
...
@@ -98,6 +98,13 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
seq_groups_right
=
model_input
.
sampling_metadata
.
seq_groups
[
batch_size_left
:]
seq_groups_right
=
model_input
.
sampling_metadata
.
seq_groups
[
batch_size_left
:]
selected_token_indices_left
=
split_seq_lens_tensor
[
0
].
cumsum
(
dim
=
0
)
-
1
selected_token_indices_left
=
split_seq_lens_tensor
[
0
].
cumsum
(
dim
=
0
)
-
1
selected_token_indices_right
=
split_seq_lens_tensor
[
1
].
cumsum
(
dim
=
0
)
-
1
selected_token_indices_right
=
split_seq_lens_tensor
[
1
].
cumsum
(
dim
=
0
)
-
1
previous_hidden_states_left
=
None
previous_hidden_states_right
=
None
if
model_input
.
previous_hidden_states
!=
None
:
split_previous_hidden_states
=
torch
.
split
(
model_input
.
previous_hidden_states
,
batch_size_split
,
dim
=
0
)
previous_hidden_states_left
=
split_previous_hidden_states
[
0
]
previous_hidden_states_right
=
split_previous_hidden_states
[
1
]
if
isinstance
(
model_input
.
attn_metadata
,
MLACommonMetadata
):
if
isinstance
(
model_input
.
attn_metadata
,
MLACommonMetadata
):
attn_metadata_left
=
MLACommonMetadata
(
attn_metadata_left
=
MLACommonMetadata
(
...
@@ -302,7 +309,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
...
@@ -302,7 +309,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
virtual_engine
=
model_input
.
virtual_engine
,
virtual_engine
=
model_input
.
virtual_engine
,
async_callback
=
model_input
.
async_callback
,
async_callback
=
model_input
.
async_callback
,
scheduler_outputs
=
model_input
.
scheduler_outputs
,
scheduler_outputs
=
model_input
.
scheduler_outputs
,
previous_hidden_states
=
model_input
.
previous_hidden_states
,
previous_hidden_states
=
previous_hidden_states
_left
,
sampling_metadata
=
SamplingMetadata
(
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
seq_groups_left
,
seq_groups
=
seq_groups_left
,
selected_token_indices
=
selected_token_indices_left
,
selected_token_indices
=
selected_token_indices_left
,
...
@@ -330,7 +337,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
...
@@ -330,7 +337,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
virtual_engine
=
model_input
.
virtual_engine
,
virtual_engine
=
model_input
.
virtual_engine
,
async_callback
=
model_input
.
async_callback
,
async_callback
=
model_input
.
async_callback
,
scheduler_outputs
=
model_input
.
scheduler_outputs
,
scheduler_outputs
=
model_input
.
scheduler_outputs
,
previous_hidden_states
=
model_input
.
previous_hidden_states
,
previous_hidden_states
=
previous_hidden_states
_right
,
sampling_metadata
=
SamplingMetadata
(
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
seq_groups_right
,
seq_groups
=
seq_groups_right
,
selected_token_indices
=
selected_token_indices_right
,
selected_token_indices
=
selected_token_indices_right
,
...
...
vllm/two_batch_overlap/two_batch_overlap.py
View file @
7124e74d
...
@@ -86,6 +86,11 @@ class TwoBatchOverlap():
...
@@ -86,6 +86,11 @@ class TwoBatchOverlap():
break
break
profile
.
ProfRangePush
(
'start'
)
profile
.
ProfRangePush
(
'start'
)
self
.
tbo_thread_synchronize
(
tid
)
self
.
tbo_thread_synchronize
(
tid
)
model_kwargs
=
None
if
is_left_thread
:
model_kwargs
=
self
.
model_kwargs_left
else
:
model_kwargs
=
self
.
model_kwargs_right
with
set_forward_context
(
model_input
.
attn_metadata
,
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
self
.
virtual_engine
):
self
.
vllm_config
,
self
.
virtual_engine
):
hidden_or_intermediate_states
=
self
.
model_executable
(
hidden_or_intermediate_states
=
self
.
model_executable
(
...
@@ -95,7 +100,7 @@ class TwoBatchOverlap():
...
@@ -95,7 +100,7 @@ class TwoBatchOverlap():
**
MultiModalKwargs
.
as_kwargs
(
self
.
multi_modal_kwargs
,
**
MultiModalKwargs
.
as_kwargs
(
self
.
multi_modal_kwargs
,
device
=
self
.
self_device
),
device
=
self
.
self_device
),
**
self
.
seqlen_agnostic_kwargs
,
**
self
.
seqlen_agnostic_kwargs
,
**
self
.
model_kwargs
,
**
model_kwargs
,
)
)
if
is_left_thread
:
if
is_left_thread
:
self
.
sem_right
.
release
()
self
.
sem_right
.
release
()
...
@@ -131,7 +136,8 @@ class TwoBatchOverlap():
...
@@ -131,7 +136,8 @@ class TwoBatchOverlap():
multi_modal_kwargs
,
multi_modal_kwargs
,
self_device
,
self_device
,
seqlen_agnostic_kwargs
,
seqlen_agnostic_kwargs
,
model_kwargs
):
model_kwargs_left
,
model_kwargs_right
):
if
self
.
left_thread
==
None
:
if
self
.
left_thread
==
None
:
self
.
init_tbo_thread
()
self
.
init_tbo_thread
()
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
...
@@ -141,7 +147,8 @@ class TwoBatchOverlap():
...
@@ -141,7 +147,8 @@ class TwoBatchOverlap():
self
.
multi_modal_kwargs
=
multi_modal_kwargs
self
.
multi_modal_kwargs
=
multi_modal_kwargs
self
.
self_device
=
self_device
self
.
self_device
=
self_device
self
.
seqlen_agnostic_kwargs
=
seqlen_agnostic_kwargs
self
.
seqlen_agnostic_kwargs
=
seqlen_agnostic_kwargs
self
.
model_kwargs
=
model_kwargs
self
.
model_kwargs_left
=
model_kwargs_left
self
.
model_kwargs_right
=
model_kwargs_right
self
.
model_input_left_queue
.
put
(
model_input_left
)
self
.
model_input_left_queue
.
put
(
model_input_left
)
self
.
model_input_right_queue
.
put
(
model_input_right
)
self
.
model_input_right_queue
.
put
(
model_input_right
)
...
@@ -242,6 +249,16 @@ def tbo_model_executable(
...
@@ -242,6 +249,16 @@ def tbo_model_executable(
batch_size_right
+=
1
batch_size_right
+=
1
model_input_left
,
model_input_right
=
split_model_input
(
model_input
,
self_device
,
batch_size_left
,
batch_size_right
)
model_input_left
,
model_input_right
=
split_model_input
(
model_input
,
self_device
,
batch_size_left
,
batch_size_right
)
model_kwargs_left
=
model_kwargs
model_kwargs_right
=
model_kwargs
if
"previous_hidden_states"
in
model_kwargs
:
previous_hidden_states
=
model_kwargs
[
"previous_hidden_states"
]
batch_size_split
=
[
batch_size_left
,
batch_size_right
]
split_previous_hidden_states
=
torch
.
split
(
previous_hidden_states
,
batch_size_split
,
dim
=
0
)
model_kwargs_left
[
"previous_hidden_states"
]
=
split_previous_hidden_states
[
0
]
model_kwargs_right
[
"previous_hidden_states"
]
=
split_previous_hidden_states
[
1
]
tbo_obj
.
step_event
.
record
()
tbo_obj
.
step_event
.
record
()
current_stream
=
torch
.
cuda
.
current_stream
()
current_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
...
@@ -255,7 +272,8 @@ def tbo_model_executable(
...
@@ -255,7 +272,8 @@ def tbo_model_executable(
multi_modal_kwargs
,
multi_modal_kwargs
,
self_device
,
self_device
,
seqlen_agnostic_kwargs
,
seqlen_agnostic_kwargs
,
model_kwargs
)
model_kwargs_left
,
model_kwargs_right
)
tbo_obj
.
all_reduce
()
tbo_obj
.
all_reduce
()
states_left
,
states_right
=
tbo_obj
.
get_model_output
()
states_left
,
states_right
=
tbo_obj
.
get_model_output
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment