Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1a906ab9
Commit
1a906ab9
authored
Jun 04, 2025
by
lizhigong
Browse files
fix tbo to support pipeline-parallel
parent
0c5b1695
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
28 deletions
+41
-28
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+3
-0
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+4
-0
vllm/two_batch_overlap/model_input_split.py
vllm/two_batch_overlap/model_input_split.py
+2
-23
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+32
-5
No files found.
vllm/entrypoints/launcher.py
View file @
1a906ab9
...
@@ -77,6 +77,9 @@ async def serve_http(app: FastAPI,
...
@@ -77,6 +77,9 @@ async def serve_http(app: FastAPI,
"port %s is used by process %s launched with command:
\n
%s"
,
"port %s is used by process %s launched with command:
\n
%s"
,
port
,
process
,
" "
.
join
(
process
.
cmdline
()))
port
,
process
,
" "
.
join
(
process
.
cmdline
()))
logger
.
info
(
"Shutting down FastAPI HTTP server."
)
logger
.
info
(
"Shutting down FastAPI HTTP server."
)
from
vllm.two_batch_overlap.two_batch_overlap
import
finish_two_batch_overlap
finish_two_batch_overlap
()
return
server
.
shutdown
()
return
server
.
shutdown
()
finally
:
finally
:
watchdog_task
.
cancel
()
watchdog_task
.
cancel
()
...
...
vllm/executor/multiproc_worker_utils.py
View file @
1a906ab9
...
@@ -256,6 +256,10 @@ def _run_worker_process(
...
@@ -256,6 +256,10 @@ def _run_worker_process(
and
not
tunable
.
record_untuned_is_enabled
()):
and
not
tunable
.
record_untuned_is_enabled
()):
tunable
.
write_file
()
tunable
.
write_file
()
from
vllm.two_batch_overlap.two_batch_overlap
import
finish_two_batch_overlap
finish_two_batch_overlap
()
logger
.
info
(
"Worker exiting"
)
logger
.
info
(
"Worker exiting"
)
...
...
vllm/two_batch_overlap/model_input_split.py
View file @
1a906ab9
...
@@ -91,13 +91,6 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
...
@@ -91,13 +91,6 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
else
:
else
:
request_ids_to_seq_ids_right
[
key
]
=
value
request_ids_to_seq_ids_right
[
key
]
=
value
counter
+=
1
counter
+=
1
seq_groups_left
=
None
seq_groups_right
=
None
if
model_input
.
sampling_metadata
.
seq_groups
is
not
None
:
seq_groups_left
=
model_input
.
sampling_metadata
.
seq_groups
[
0
:
batch_size_left
]
seq_groups_right
=
model_input
.
sampling_metadata
.
seq_groups
[
batch_size_left
:]
selected_token_indices_left
=
split_seq_lens_tensor
[
0
].
cumsum
(
dim
=
0
)
-
1
selected_token_indices_right
=
split_seq_lens_tensor
[
1
].
cumsum
(
dim
=
0
)
-
1
previous_hidden_states_left
=
None
previous_hidden_states_left
=
None
previous_hidden_states_right
=
None
previous_hidden_states_right
=
None
...
@@ -310,14 +303,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
...
@@ -310,14 +303,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
async_callback
=
model_input
.
async_callback
,
async_callback
=
model_input
.
async_callback
,
scheduler_outputs
=
model_input
.
scheduler_outputs
,
scheduler_outputs
=
model_input
.
scheduler_outputs
,
previous_hidden_states
=
previous_hidden_states_left
,
previous_hidden_states
=
previous_hidden_states_left
,
sampling_metadata
=
SamplingMetadata
(
sampling_metadata
=
None
,
#TBO does not require sampling_stetadata
seq_groups
=
seq_groups_left
,
selected_token_indices
=
selected_token_indices_left
,
categorized_sample_indices
=
model_input
.
sampling_metadata
.
categorized_sample_indices
,
num_prompts
=
num_prefills_left
,
skip_sampler_cpu_output
=
model_input
.
sampling_metadata
.
skip_sampler_cpu_output
,
reuse_sampling_tensors
=
model_input
.
sampling_metadata
.
reuse_sampling_tensors
,
),
is_prompt
=
model_input
.
is_prompt
,
is_prompt
=
model_input
.
is_prompt
,
)
)
model_input_right
=
ModelInputForGPUWithSamplingMetadata
(
model_input_right
=
ModelInputForGPUWithSamplingMetadata
(
...
@@ -338,14 +324,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
...
@@ -338,14 +324,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
async_callback
=
model_input
.
async_callback
,
async_callback
=
model_input
.
async_callback
,
scheduler_outputs
=
model_input
.
scheduler_outputs
,
scheduler_outputs
=
model_input
.
scheduler_outputs
,
previous_hidden_states
=
previous_hidden_states_right
,
previous_hidden_states
=
previous_hidden_states_right
,
sampling_metadata
=
SamplingMetadata
(
sampling_metadata
=
None
,
#TBO does not require sampling_stetadata
seq_groups
=
seq_groups_right
,
selected_token_indices
=
selected_token_indices_right
,
categorized_sample_indices
=
model_input
.
sampling_metadata
.
categorized_sample_indices
,
num_prompts
=
num_prefills_right
,
skip_sampler_cpu_output
=
model_input
.
sampling_metadata
.
skip_sampler_cpu_output
,
reuse_sampling_tensors
=
model_input
.
sampling_metadata
.
reuse_sampling_tensors
,
),
is_prompt
=
model_input
.
is_prompt
,
is_prompt
=
model_input
.
is_prompt
,
)
)
return
model_input_left
,
model_input_right
return
model_input_left
,
model_input_right
vllm/two_batch_overlap/two_batch_overlap.py
View file @
1a906ab9
...
@@ -4,8 +4,10 @@ import queue
...
@@ -4,8 +4,10 @@ import queue
import
threading
import
threading
import
torch
import
torch
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.forward_context
import
init_tbo_forward_context
from
vllm.two_batch_overlap.forward_context
import
init_tbo_forward_context
from
vllm.two_batch_overlap.model_input_split
import
is_supported_attention_metadata
,
split_model_input
from
vllm.two_batch_overlap.model_input_split
import
is_supported_attention_metadata
,
split_model_input
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -87,16 +89,20 @@ class TwoBatchOverlap():
...
@@ -87,16 +89,20 @@ class TwoBatchOverlap():
profile
.
ProfRangePush
(
'start'
)
profile
.
ProfRangePush
(
'start'
)
self
.
tbo_thread_synchronize
(
tid
)
self
.
tbo_thread_synchronize
(
tid
)
model_kwargs
=
None
model_kwargs
=
None
intermediate_tensors
=
None
if
is_left_thread
:
if
is_left_thread
:
model_kwargs
=
self
.
model_kwargs_left
model_kwargs
=
self
.
model_kwargs_left
intermediate_tensors
=
self
.
intermediate_tensors_left
else
:
else
:
model_kwargs
=
self
.
model_kwargs_right
model_kwargs
=
self
.
model_kwargs_right
intermediate_tensors
=
self
.
intermediate_tensors_right
with
set_forward_context
(
model_input
.
attn_metadata
,
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
self
.
virtual_engine
):
self
.
vllm_config
,
self
.
virtual_engine
):
hidden_or_intermediate_states
=
self
.
model_executable
(
hidden_or_intermediate_states
=
self
.
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
intermediate_tensors
=
self
.
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
self
.
multi_modal_kwargs
,
**
MultiModalKwargs
.
as_kwargs
(
self
.
multi_modal_kwargs
,
device
=
self
.
self_device
),
device
=
self
.
self_device
),
**
self
.
seqlen_agnostic_kwargs
,
**
self
.
seqlen_agnostic_kwargs
,
...
@@ -132,7 +138,8 @@ class TwoBatchOverlap():
...
@@ -132,7 +138,8 @@ class TwoBatchOverlap():
vllm_config
,
vllm_config
,
virtual_engine
,
virtual_engine
,
model_executable
,
model_executable
,
intermediate_tensors
,
intermediate_tensors_left
,
intermediate_tensors_right
,
multi_modal_kwargs
,
multi_modal_kwargs
,
self_device
,
self_device
,
seqlen_agnostic_kwargs
,
seqlen_agnostic_kwargs
,
...
@@ -143,7 +150,8 @@ class TwoBatchOverlap():
...
@@ -143,7 +150,8 @@ class TwoBatchOverlap():
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
virtual_engine
=
virtual_engine
self
.
virtual_engine
=
virtual_engine
self
.
model_executable
=
model_executable
self
.
model_executable
=
model_executable
self
.
intermediate_tensors
=
intermediate_tensors
self
.
intermediate_tensors_left
=
intermediate_tensors_left
self
.
intermediate_tensors_right
=
intermediate_tensors_right
self
.
multi_modal_kwargs
=
multi_modal_kwargs
self
.
multi_modal_kwargs
=
multi_modal_kwargs
self
.
self_device
=
self_device
self
.
self_device
=
self_device
self
.
seqlen_agnostic_kwargs
=
seqlen_agnostic_kwargs
self
.
seqlen_agnostic_kwargs
=
seqlen_agnostic_kwargs
...
@@ -204,7 +212,13 @@ def tbo_all_reduce(obj):
...
@@ -204,7 +212,13 @@ def tbo_all_reduce(obj):
return
tensor_model_parallel_all_reduce
(
obj
)
return
tensor_model_parallel_all_reduce
(
obj
)
def
merge_model_output
(
states_left
,
states_right
):
def
merge_model_output
(
states_left
,
states_right
):
output
=
torch
.
concat
([
states_left
,
states_right
],
dim
=
0
)
if
isinstance
(
states_left
,
IntermediateTensors
):
output_map
=
{}
for
key
in
states_left
.
tensors
:
output_map
[
key
]
=
torch
.
concat
([
states_left
.
tensors
[
key
],
states_right
.
tensors
[
key
]],
dim
=
0
)
output
=
IntermediateTensors
(
output_map
)
else
:
output
=
torch
.
concat
([
states_left
,
states_right
],
dim
=
0
)
return
output
return
output
def
tbo_model_executable
(
def
tbo_model_executable
(
...
@@ -252,12 +266,24 @@ def tbo_model_executable(
...
@@ -252,12 +266,24 @@ def tbo_model_executable(
model_kwargs_left
=
model_kwargs
.
copy
()
model_kwargs_left
=
model_kwargs
.
copy
()
model_kwargs_right
=
model_kwargs
.
copy
()
model_kwargs_right
=
model_kwargs
.
copy
()
intermediate_tensors_left
=
None
intermediate_tensors_right
=
None
if
"previous_hidden_states"
in
model_kwargs
:
if
"previous_hidden_states"
in
model_kwargs
:
previous_hidden_states
=
model_kwargs
[
"previous_hidden_states"
]
previous_hidden_states
=
model_kwargs
[
"previous_hidden_states"
]
query_tokens_split
=
[
sum
(
model_input
.
query_lens
[
0
:
batch_size_left
]),
sum
(
model_input
.
query_lens
[
batch_size_left
:])]
query_tokens_split
=
[
sum
(
model_input
.
query_lens
[
0
:
batch_size_left
]),
sum
(
model_input
.
query_lens
[
batch_size_left
:])]
split_previous_hidden_states
=
torch
.
split
(
previous_hidden_states
,
query_tokens_split
,
dim
=
0
)
split_previous_hidden_states
=
torch
.
split
(
previous_hidden_states
,
query_tokens_split
,
dim
=
0
)
model_kwargs_left
[
"previous_hidden_states"
]
=
split_previous_hidden_states
[
0
]
model_kwargs_left
[
"previous_hidden_states"
]
=
split_previous_hidden_states
[
0
]
model_kwargs_right
[
"previous_hidden_states"
]
=
split_previous_hidden_states
[
1
]
model_kwargs_right
[
"previous_hidden_states"
]
=
split_previous_hidden_states
[
1
]
if
intermediate_tensors
!=
None
:
query_tokens_split
=
[
sum
(
model_input
.
query_lens
[
0
:
batch_size_left
]),
sum
(
model_input
.
query_lens
[
batch_size_left
:])]
intermediate_tensors_left
=
{}
intermediate_tensors_right
=
{}
for
key
in
intermediate_tensors
.
tensors
:
split_intermediate_tensors
=
torch
.
split
(
intermediate_tensors
.
tensors
[
key
],
query_tokens_split
,
dim
=
0
)
intermediate_tensors_left
[
key
]
=
split_intermediate_tensors
[
0
]
intermediate_tensors_right
[
key
]
=
split_intermediate_tensors
[
1
]
intermediate_tensors_left
=
IntermediateTensors
(
intermediate_tensors_left
)
intermediate_tensors_right
=
IntermediateTensors
(
intermediate_tensors_right
)
tbo_obj
.
step_event
.
record
()
tbo_obj
.
step_event
.
record
()
current_stream
=
torch
.
cuda
.
current_stream
()
current_stream
=
torch
.
cuda
.
current_stream
()
...
@@ -268,7 +294,8 @@ def tbo_model_executable(
...
@@ -268,7 +294,8 @@ def tbo_model_executable(
vllm_config
,
vllm_config
,
virtual_engine
,
virtual_engine
,
model_executable
,
model_executable
,
intermediate_tensors
,
intermediate_tensors_left
,
intermediate_tensors_right
,
multi_modal_kwargs
,
multi_modal_kwargs
,
self_device
,
self_device
,
seqlen_agnostic_kwargs
,
seqlen_agnostic_kwargs
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment