Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fe1c4016
Commit
fe1c4016
authored
Jul 10, 2025
by
zhuwenwen
Browse files
add two-batch-overlap decode support for multi-stream cuda-graph
parent
d805c59c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
286 additions
and
17 deletions
+286
-17
vllm/two_batch_overlap/model_input_split.py
vllm/two_batch_overlap/model_input_split.py
+69
-0
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+207
-13
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+10
-4
No files found.
vllm/two_batch_overlap/model_input_split.py
View file @
fe1c4016
...
...
@@ -328,3 +328,72 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
is_prompt
=
model_input
.
is_prompt
,
)
return
model_input_left
,
model_input_right
def _build_capture_decode_metadata(attn_metadata, num_decode_tokens,
                                   seq_lens_tensor, block_tables,
                                   slot_mapping):
    """Build decode-only ROCm attention metadata for one half of a batch.

    Prefill-related fields are zeroed or set to None and the query length is
    fixed to 1; the remaining scalar settings are copied from
    ``attn_metadata``.
    """
    return ROCmFlashAttentionMetadata(
        seq_lens_tensor=seq_lens_tensor,
        max_decode_seq_len=attn_metadata.max_decode_seq_len,
        block_tables=block_tables,
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=num_decode_tokens,
        slot_mapping=slot_mapping,
        multi_modal_placeholder_index_maps=(
            attn_metadata.multi_modal_placeholder_index_maps),
        enable_kv_scales_calculation=(
            attn_metadata.enable_kv_scales_calculation),
        seq_lens=None,
        max_prefill_seq_len=0,
        use_cuda_graph=attn_metadata.use_cuda_graph,
        max_query_len=1,
        query_start_loc=None,
        seq_start_loc=None,
        context_lens_tensor=None,
        max_decode_query_len=1,
        _cached_prefill_metadata=None,
        _cached_decode_metadata=None,
        tree_attention_masks_tensor=None,
        block_tables_list=None,
        encoder_seq_lens=None,
        encoder_seq_lens_tensor=None,
        max_encoder_seq_len=None,
        num_encoder_tokens=None,
        cross_slot_mapping=None,
        cross_block_tables=None,
    )


def split_capture_attention_metadata(attn_metadata, batch_size_left,
                                     batch_size_right):
    """Split capture-time attention metadata into a left and a right half.

    ``seq_lens_tensor``, ``block_tables`` and ``slot_mapping`` are split along
    dim 0 into chunks of ``batch_size_left`` and ``batch_size_right`` rows.

    Args:
        attn_metadata: Metadata of the full batch. Only
            ``ROCmFlashAttentionMetadata`` is supported.
        batch_size_left: Number of rows assigned to the left half.
        batch_size_right: Number of rows assigned to the right half.

    Returns:
        A ``(attn_metadata_left, attn_metadata_right)`` tuple.

    Raises:
        NotImplementedError: If ``attn_metadata`` is not a
            ``ROCmFlashAttentionMetadata``.
    """
    # Fail loudly up front: the previous code only printed a message and then
    # hit an UnboundLocalError on the return statement.
    if not isinstance(attn_metadata, ROCmFlashAttentionMetadata):
        raise NotImplementedError(
            "tbo: attention metadata type "
            f"{type(attn_metadata)} is not supported in cuda-graph capture")

    batch_size_split = [batch_size_left, batch_size_right]
    # NOTE(review): splitting slot_mapping by batch size assumes one slot per
    # sequence (pure decode) -- confirm against the capture caller.
    seq_lens_left, seq_lens_right = torch.split(
        attn_metadata.seq_lens_tensor, batch_size_split, dim=0)
    block_tables_left, block_tables_right = torch.split(
        attn_metadata.block_tables, batch_size_split, dim=0)
    slot_mapping_left, slot_mapping_right = torch.split(
        attn_metadata.slot_mapping, batch_size_split, dim=0)

    attn_metadata_left = _build_capture_decode_metadata(
        attn_metadata, batch_size_left, seq_lens_left, block_tables_left,
        slot_mapping_left)
    attn_metadata_right = _build_capture_decode_metadata(
        attn_metadata, batch_size_right, seq_lens_right, block_tables_right,
        slot_mapping_right)
    return attn_metadata_left, attn_metadata_right
vllm/two_batch_overlap/two_batch_overlap.py
View file @
fe1c4016
import
gc
import
os
import
queue
import
threading
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.forward_context
import
init_tbo_forward_context
from
vllm.two_batch_overlap.model_input_split
import
is_supported_attention_metadata
,
split_model_input
from
vllm.two_batch_overlap.model_input_split
import
is_supported_attention_metadata
,
split_capture_attention_metadata
,
split_model_input
from
vllm.logger
import
init_logger
from
vllm.profiler.prof
import
profile
from
vllm
import
envs
from
vllm.utils
import
weak_ref_tensor
tbo_one_stream
=
os
.
environ
.
get
(
'VLLM_TBO_ONE_STREAM'
)
==
'1'
...
...
@@ -37,6 +40,7 @@ class TwoBatchOverlap():
self
.
sem_right
=
threading
.
Semaphore
(
0
)
self
.
left_first
=
False
self
.
tbo_running
=
False
self
.
tbo_in_capture
=
False
if
tbo_step_stream
==
None
:
tbo_step_stream
=
torch
.
cuda
.
Stream
()
all_reduce_stream
=
torch
.
cuda
.
Stream
()
...
...
@@ -84,6 +88,27 @@ class TwoBatchOverlap():
else
:
model_kwargs
=
self
.
model_kwargs_right
intermediate_tensors
=
self
.
intermediate_tensors_right
hidden_or_intermediate_states
=
None
if
self
.
tbo_in_capture
:
if
is_left_thread
:
attn_metadata
=
self
.
attn_metadata_left
input_tokens
=
self
.
input_tokens_left
input_positions
=
self
.
split_input_positions
[
0
]
else
:
attn_metadata
=
self
.
attn_metadata_right
input_tokens
=
self
.
input_tokens_right
input_positions
=
self
.
split_input_positions
[
1
]
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
self
.
virtual_engine
):
hidden_or_intermediate_states
=
self
.
model_executable
(
input_ids
=
input_tokens
,
positions
=
input_positions
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
self
.
multi_modal_kwargs
,
device
=
self
.
self_device
),
**
model_kwargs
,
)
elif
model_input
!=
None
:
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
self
.
virtual_engine
):
hidden_or_intermediate_states
=
self
.
model_executable
(
...
...
@@ -144,6 +169,37 @@ class TwoBatchOverlap():
self
.
model_input_left_queue
.
put
(
model_input_left
)
self
.
model_input_right_queue
.
put
(
model_input_right
)
def set_capture_model_input(self, input_tokens_left, input_tokens_right,
                            split_input_positions, vllm_config,
                            virtual_engine, runner_model, runner_device,
                            intermediate_tensors_left,
                            intermediate_tensors_right, model_kwargs_left,
                            model_kwargs_right, attn_metadata_left,
                            attn_metadata_right):
    """Stash the per-half capture inputs on this object and signal both queues.

    Every argument is stored on ``self`` under the same name, except
    ``runner_model`` (stored as ``model_executable``) and ``runner_device``
    (stored as ``self_device``). Afterwards ``None`` is put on both the left
    and the right model-input queue.
    """
    # Settings shared by both halves.
    self.vllm_config, self.virtual_engine = vllm_config, virtual_engine
    self.model_executable, self.self_device = runner_model, runner_device
    self.split_input_positions = split_input_positions

    # Left half.
    self.input_tokens_left = input_tokens_left
    self.intermediate_tensors_left = intermediate_tensors_left
    self.model_kwargs_left = model_kwargs_left
    self.attn_metadata_left = attn_metadata_left

    # Right half.
    self.input_tokens_right = input_tokens_right
    self.intermediate_tensors_right = intermediate_tensors_right
    self.model_kwargs_right = model_kwargs_right
    self.attn_metadata_right = attn_metadata_right

    # Enqueue only after all state is in place; the None entry presumably
    # wakes the consumer of each queue -- verify against the worker loop.
    for pending_queue in (self.model_input_left_queue,
                          self.model_input_right_queue):
        pending_queue.put(None)
def
get_model_output
(
self
):
states_left
=
self
.
states_left_queue
.
get
()
states_right
=
self
.
states_right_queue
.
get
()
...
...
@@ -280,3 +336,141 @@ def tbo_model_executable(
current_stream
.
wait_event
(
tbo_obj
.
step_event
)
profile
.
ProfRangePop
()
return
hidden_or_intermediate_states
def _run_once(vllm_config, virtual_engine, runner, self_device,
              input_ids: torch.Tensor, positions: torch.Tensor,
              intermediate_inputs: Optional[IntermediateTensors],
              attn_metadata: AttentionMetadata, stream: torch.cuda.Stream,
              **kwargs):
    """Run one two-batch-overlap forward pass in capture mode.

    The batch is cut along dim 0 into a left half of
    ``input_ids.shape[0] // 2`` rows and a right half with the remainder.
    Token ids, positions, ``previous_hidden_states`` (if present in
    ``kwargs``), intermediate tensors and attention metadata are split the
    same way, handed to ``tbo_obj``, and the two outputs are merged.

    Args:
        vllm_config: Forwarded unchanged to ``tbo_obj``.
        virtual_engine: Forwarded unchanged to ``tbo_obj``.
        runner: Object whose ``model`` attribute is the model to execute.
        self_device: Forwarded unchanged to ``tbo_obj``.
        input_ids: Token ids of the full batch.
        positions: Positions of the full batch.
        intermediate_inputs: Optional intermediate tensors of the full batch.
        attn_metadata: Attention metadata of the full batch.
        stream: Stream installed as the module-level ``tbo_step_stream`` for
            the duration of the call.
        **kwargs: Extra model keyword arguments, copied for each half.

    Returns:
        The result of ``merge_model_output`` on the two half outputs.
    """
    global tbo_step_stream
    stream_back = tbo_step_stream
    tbo_step_stream = stream
    try:
        init_two_batch_overlap()
        tbo_obj.left_first = True

        decode_batch_size = input_ids.shape[0]
        batch_size_left = decode_batch_size // 2
        batch_size_right = decode_batch_size - batch_size_left
        query_tokens_split = [batch_size_left, batch_size_right]

        input_tokens_left, input_tokens_right = torch.split(
            input_ids, query_tokens_split, dim=0)
        split_input_positions = torch.split(
            positions, query_tokens_split, dim=0)

        # Each half gets its own kwargs dict so per-half overrides do not
        # leak into the other half or into the caller's kwargs.
        model_kwargs_left = kwargs.copy()
        model_kwargs_right = kwargs.copy()
        if "previous_hidden_states" in kwargs:
            split_previous_hidden_states = torch.split(
                kwargs["previous_hidden_states"], query_tokens_split, dim=0)
            model_kwargs_left["previous_hidden_states"] = (
                split_previous_hidden_states[0])
            model_kwargs_right["previous_hidden_states"] = (
                split_previous_hidden_states[1])

        intermediate_tensors_left = None
        intermediate_tensors_right = None
        if intermediate_inputs is not None:
            tensors_left = {}
            tensors_right = {}
            for key, tensor in intermediate_inputs.tensors.items():
                tensors_left[key], tensors_right[key] = torch.split(
                    tensor, query_tokens_split, dim=0)
            intermediate_tensors_left = IntermediateTensors(tensors_left)
            intermediate_tensors_right = IntermediateTensors(tensors_right)

        attn_metadata_left, attn_metadata_right = (
            split_capture_attention_metadata(attn_metadata, batch_size_left,
                                             batch_size_right))

        tbo_obj.tbo_running = True
        tbo_obj.tbo_in_capture = True
        try:
            tbo_obj.set_capture_model_input(
                input_tokens_left, input_tokens_right, split_input_positions,
                vllm_config, virtual_engine, runner.model, self_device,
                intermediate_tensors_left, intermediate_tensors_right,
                model_kwargs_left, model_kwargs_right, attn_metadata_left,
                attn_metadata_right)
            states_left, states_right = tbo_obj.get_model_output()
            output_hidden_or_intermediate_states = merge_model_output(
                states_left, states_right)
        finally:
            # Always leave capture mode and stop the workers, so a failed
            # pass cannot leave the flags stuck at True.
            tbo_obj.tbo_in_capture = False
            tbo_obj.tbo_running = False
            tbo_obj.finish_thread()
    finally:
        # Restore the previous step stream even when the pass raised.
        tbo_step_stream = stream_back
    return output_hidden_or_intermediate_states
def tbo_capture(vllm_config, virtual_engine, _NUM_WARMUP_ITERS, runner,
                self_device, input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_inputs: Optional[IntermediateTensors],
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                memory_pool: Optional[Tuple[int, int]],
                stream: torch.cuda.Stream, **kwargs):
    """Capture a CUDA graph of one two-batch-overlap pass into ``runner``.

    Runs ``_NUM_WARMUP_ITERS`` warm-up passes, synchronizes, then records one
    more pass into a new ``torch.cuda.CUDAGraph`` stored as ``runner._graph``.
    The recorded output is kept only through weak tensor references.
    Finally ``runner.input_buffers`` and ``runner.output_buffers`` are set.

    Args:
        vllm_config, virtual_engine, self_device: Forwarded to ``_run_once``.
        _NUM_WARMUP_ITERS: Number of warm-up passes before recording.
        runner: Object that receives ``_graph``, ``input_buffers`` and
            ``output_buffers``; its ``attn_state`` supplies extra buffers.
        input_ids, positions, intermediate_inputs, attn_metadata: Inputs of
            the pass, also stored as graph input buffers.
        kv_caches: Stored in the input buffers; not passed to ``_run_once``.
        memory_pool: Passed as ``pool`` to ``torch.cuda.graph``.
        stream: Passed as ``stream`` to ``torch.cuda.graph``.
        **kwargs: Extra model keyword arguments, also stored as buffers.
    """

    def run_split_pass():
        # The current stream is looked up on every call, so a pass made
        # inside the graph context uses whatever stream is current there.
        return _run_once(vllm_config, virtual_engine, runner, self_device,
                         input_ids, positions, intermediate_inputs,
                         attn_metadata, torch.cuda.current_stream(),
                         **kwargs)

    for _ in range(_NUM_WARMUP_ITERS):
        run_split_pass()
    # Let the warm-up work finish before recording starts.
    torch.cuda.synchronize()

    runner._graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(runner._graph, pool=memory_pool, stream=stream):
        output_hidden_or_intermediate_states = run_split_pass()

        if isinstance(output_hidden_or_intermediate_states, torch.Tensor):
            hidden_or_intermediate_states = weak_ref_tensor(
                output_hidden_or_intermediate_states)
        elif isinstance(output_hidden_or_intermediate_states,
                        IntermediateTensors):
            weak_tensors = {}
            for key, value in (
                    output_hidden_or_intermediate_states.tensors.items()):
                weak_tensors[key] = weak_ref_tensor(value)
            hidden_or_intermediate_states = IntermediateTensors(
                tensors=weak_tensors)

        del output_hidden_or_intermediate_states
        # make sure `output_hidden_or_intermediate_states` is deleted
        # in the graph's memory pool
        gc.collect()
    torch.cuda.synchronize()

    # Save the input and output buffers. Later updates override earlier
    # keys, matching the order of the original dict literal.
    graph_inputs = {
        "input_ids": input_ids,
        "positions": positions,
        "kv_caches": kv_caches,
    }
    graph_inputs.update(
        runner.attn_state.get_graph_input_buffers(
            attn_metadata, runner._is_encoder_decoder_model))
    graph_inputs.update(kwargs)
    if intermediate_inputs is not None:
        graph_inputs.update(intermediate_inputs.tensors)
    runner.input_buffers = graph_inputs

    on_last_rank = get_pp_group().is_last_rank
    runner.output_buffers = ({
        "hidden_states": hidden_or_intermediate_states
    } if on_last_rank else hidden_or_intermediate_states)
vllm/worker/model_runner.py
View file @
fe1c4016
...
...
@@ -52,7 +52,7 @@ from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_model_executable
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_capture
,
tbo_model_executable
from
vllm.utils
import
(
DeviceMemoryProfiler
,
GiB_bytes
,
PyObjectCache
,
async_tensor_h2d
,
flatten_2d_lists
,
is_pin_memory_available
,
supports_dynamo
,
...
...
@@ -1668,6 +1668,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
_update_inputs_to_capture_for_enc_dec_model
(
capture_inputs
)
if
envs
.
VLLM_ENABLE_TBO
and
envs
.
VLLM_TBO_DECODE_BS
>
1
and
batch_size
>=
envs
.
VLLM_TBO_DECODE_BS
:
tbo_capture
(
self
.
vllm_config
,
virtual_engine
,
_NUM_WARMUP_ITERS
,
graph_runner
,
self
.
device
,
**
capture_inputs
)
else
:
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
virtual_engine
):
graph_runner
.
capture
(
**
capture_inputs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment