Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
20e75ed6
Commit
20e75ed6
authored
Aug 02, 2025
by
lizhigong
Committed by
maxiao1@sugon.com
Aug 04, 2025
Browse files
add tbo on v1 engine
parent
eba84521
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
892 additions
and
2 deletions
+892
-2
vllm/envs.py
vllm/envs.py
+5
-0
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+3
-0
vllm/two_batch_overlap/v1/gpu_model_runner.py
vllm/two_batch_overlap/v1/gpu_model_runner.py
+638
-0
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
+239
-0
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+7
-2
No files found.
vllm/envs.py
View file @
20e75ed6
...
...
@@ -159,6 +159,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_TBO
:
bool
=
False
VLLM_TBO_REQ_DELAY_MS
:
int
=
0
VLLM_TBO_DECODE_BS
:
int
=
0
VLLM_TBO_MIN_TOKENS
:
int
=
200
VLLM_ZERO_OVERHEAD
:
bool
=
False
VLLM_ENABLE_MOE_FUSED_GATE
:
bool
=
False
VLLM_USE_FLASH_ATTN_PA
:
bool
=
False
...
...
@@ -1069,6 +1070,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TBO_DECODE_BS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_TBO_DECODE_BS"
,
"0"
)),
# set the minimum tokens size for each mini-batch to enable TBO on v1, default is 200.
"VLLM_TBO_MIN_TOKENS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_TBO_MIN_TOKENS"
,
"200"
)),
# Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ZERO_OVERHEAD"
,
"0"
))),
...
...
vllm/two_batch_overlap/two_batch_overlap.py
View file @
20e75ed6
...
...
@@ -16,6 +16,7 @@ from vllm.logger import init_logger
from
vllm.profiler.prof
import
profile
from
vllm
import
envs
from
vllm.utils
import
weak_ref_tensor
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
is_enable_tbo_v1
,
tbo_all_reduce_v1
tbo_one_stream
=
os
.
environ
.
get
(
'VLLM_TBO_ONE_STREAM'
)
==
'1'
...
...
@@ -214,6 +215,8 @@ def init_two_batch_overlap():
tbo_obj
.
init_tbo_thread
()
def
tbo_all_reduce
(
obj
):
if
is_enable_tbo_v1
():
return
tbo_all_reduce_v1
(
obj
)
if
envs
.
VLLM_ENABLE_TBO
and
tbo_obj
!=
None
and
tbo_obj
.
tbo_running
:
tid
=
threading
.
get_ident
()
if
not
tbo_one_stream
:
...
...
vllm/two_batch_overlap/v1/gpu_model_runner.py
0 → 100644
View file @
20e75ed6
from
typing
import
Any
,
Optional
,
Union
import
numpy
as
np
import
torch
from
vllm
import
envs
from
vllm.distributed.kv_transfer.kv_transfer_state
import
get_kv_transfer_group
,
has_kv_transfer_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_model_executable_v1
from
vllm.utils
import
async_tensor_h2d
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
EMPTY_MODEL_RUNNER_OUTPUT
,
ModelRunnerOutput
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
class
TBO_GPUModelRunner
(
GPUModelRunner
):
def
__init__
(
self
,
vllm_config
,
device
):
super
().
__init__
(
vllm_config
,
device
)
self
.
req_ids_left
=
[]
self
.
req_ids_right
=
[]
self
.
req_num_left
=
0
self
.
req_num_right
=
0
self
.
scheduler_output_left
=
None
self
.
scheduler_output_right
=
None
def
split_scheduler_output
(
self
,
scheduler_output
:
SchedulerOutput
):
split_tokens
=
scheduler_output
.
total_num_scheduled_tokens
//
2
req_ids
=
self
.
input_batch
.
req_ids
tokens_counter
=
0
min_idx
=
-
1
min_counter
=
0
for
i
,
id
in
enumerate
(
req_ids
):
tokens_counter
+=
scheduler_output
.
num_scheduled_tokens
[
id
]
diff
=
abs
(
tokens_counter
-
split_tokens
)
if
min_idx
==
-
1
or
diff
<
min_counter
:
min_idx
=
i
min_counter
=
diff
if
tokens_counter
>
split_tokens
or
diff
==
0
:
break
self
.
req_num_left
=
min_idx
+
1
if
self
.
req_num_left
==
len
(
req_ids
):
self
.
req_num_left
=
self
.
req_num_left
-
1
self
.
req_ids_left
=
req_ids
[:
self
.
req_num_left
]
self
.
req_ids_right
=
req_ids
[
self
.
req_num_left
:]
self
.
req_num_right
=
len
(
req_ids
)
-
self
.
req_num_left
new_req_data_left
=
[]
new_req_data_right
=
[]
cached_reqs_left
=
[]
cached_reqs_right
=
[]
num_scheduled_tokens_left
=
{}
num_scheduled_tokens_right
=
{}
total_num_scheduled_tokens_left
=
0
total_num_scheduled_tokens_right
=
0
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
if
new_req
.
req_id
in
self
.
req_ids_left
:
new_req_data_left
.
append
(
new_req
)
else
:
new_req_data_right
.
append
(
new_req
)
for
cached_req
in
scheduler_output
.
scheduled_cached_reqs
:
if
cached_req
.
req_id
in
self
.
req_ids_left
:
cached_reqs_left
.
append
(
cached_req
)
else
:
cached_reqs_right
.
append
(
cached_req
)
for
key
,
value
in
scheduler_output
.
num_scheduled_tokens
.
items
():
if
key
in
self
.
req_ids_left
:
num_scheduled_tokens_left
[
key
]
=
value
total_num_scheduled_tokens_left
+=
value
else
:
num_scheduled_tokens_right
[
key
]
=
value
total_num_scheduled_tokens_right
+=
value
self
.
scheduler_output_left
=
SchedulerOutput
(
scheduled_new_reqs
=
new_req_data_left
,
scheduled_cached_reqs
=
cached_reqs_left
,
num_scheduled_tokens
=
num_scheduled_tokens_left
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens_left
,
scheduled_spec_decode_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
,
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
,
##unsupport yet
num_common_prefix_blocks
=
scheduler_output
.
num_common_prefix_blocks
,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids
=
scheduler_output
.
finished_req_ids
,
free_encoder_input_ids
=
scheduler_output
.
free_encoder_input_ids
,
structured_output_request_ids
=
scheduler_output
.
structured_output_request_ids
,
grammar_bitmask
=
scheduler_output
.
grammar_bitmask
,
)
self
.
scheduler_output_right
=
SchedulerOutput
(
scheduled_new_reqs
=
new_req_data_right
,
scheduled_cached_reqs
=
cached_reqs_right
,
num_scheduled_tokens
=
num_scheduled_tokens_right
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens_right
,
scheduled_spec_decode_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
,
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
,
##unsupport yet
num_common_prefix_blocks
=
scheduler_output
.
num_common_prefix_blocks
,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids
=
scheduler_output
.
finished_req_ids
,
free_encoder_input_ids
=
scheduler_output
.
free_encoder_input_ids
,
structured_output_request_ids
=
scheduler_output
.
structured_output_request_ids
,
grammar_bitmask
=
scheduler_output
.
grammar_bitmask
,
)
def
prepare_tbo_atten_metadata
(
self
,
scheduler_output
:
"SchedulerOutput"
,
req_ids
,
req_offset
)
->
tuple
[
dict
[
str
,
Any
],
torch
.
Tensor
,
Optional
[
SpecDecodeMetadata
]]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
num_reqs
=
len
(
req_ids
)
assert
num_reqs
>
0
seq_len_offset
=
req_offset
if
req_offset
==
0
:
#left
query_start_offset
=
0
else
:
query_start_offset
=
req_offset
+
1
# Get the number of scheduled tokens for each request.
# req_ids = self.input_batch.req_ids
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
max
(
tokens
)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices
=
np
.
repeat
(
self
.
arange_np
[:
num_reqs
],
num_scheduled_tokens
)
+
req_offset
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens
,
arange
=
self
.
_get_cumsum_and_arange
(
num_scheduled_tokens
)
# Get positions.
positions_np
=
self
.
positions_np
[:
total_num_scheduled_tokens
]
np
.
add
(
self
.
input_batch
.
num_computed_tokens_cpu
[
req_indices
],
arange
,
out
=
positions_np
)
# Calculate the slot mapping for each KV cache group.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
block_size
=
kv_cache_group_spec
.
kv_cache_spec
.
block_size
block_table
:
BlockTable
=
self
.
input_batch
.
block_table
[
kv_cache_group_id
]
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices
=
(
req_indices
*
block_table
.
max_num_blocks_per_req
+
positions_np
//
block_size
)
block_table_cpu
=
block_table
.
get_cpu_tensor
()
block_numbers
=
block_table_cpu
.
flatten
(
)[
block_table_indices
].
numpy
()
block_offsets
=
positions_np
%
block_size
np
.
add
(
block_numbers
*
block_size
,
block_offsets
,
out
=
block_table
.
slot_mapping_np
[:
total_num_scheduled_tokens
])
# Prepare the attention metadata.
self
.
query_start_loc_np
[
0
]
=
0
self
.
query_start_loc_np
[
1
:
num_reqs
+
1
]
=
cu_num_tokens
self
.
seq_lens_np
[:
num_reqs
]
=
(
self
.
input_batch
.
num_computed_tokens_cpu
[
req_offset
:
req_offset
+
num_reqs
]
+
num_scheduled_tokens
)
self
.
query_start_loc
[
query_start_offset
:
query_start_offset
+
num_reqs
+
1
].
copy_
(
self
.
query_start_loc_cpu
[:
num_reqs
+
1
],
non_blocking
=
True
)
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
if
req_offset
>
0
:
#right
self
.
query_start_loc
[
query_start_offset
+
num_reqs
+
1
:].
fill_
(
self
.
query_start_loc_cpu
[
num_reqs
].
item
())
self
.
seq_lens
[
seq_len_offset
:
seq_len_offset
+
num_reqs
].
copy_
(
self
.
seq_lens_cpu
[:
num_reqs
],
non_blocking
=
True
)
# Fill unused with -1. Needed for reshape_and_cache
if
req_offset
>
0
:
#right
self
.
seq_lens
[
seq_len_offset
+
num_reqs
:].
fill_
(
0
)
query_start_loc
=
self
.
query_start_loc
[
query_start_offset
:
query_start_offset
+
num_reqs
+
1
]
seq_lens
=
self
.
seq_lens
[
seq_len_offset
:
seq_len_offset
+
num_reqs
]
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
)
attn_metadata
:
dict
[
str
,
Any
]
=
{}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
if
self
.
cascade_attn_enabled
:
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
self
.
attn_metadata_builders
[
kv_cache_group_id
],
)
if
req_offset
>
0
:
origin_block_table
=
self
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
block_table
self
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
origin_slot_mapping
=
self
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
slot_mapping
self
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
slot_mapping
=
\
origin_slot_mapping
[
self
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
attn_metadata_i
=
(
self
.
attn_metadata_builders
[
kv_cache_group_id
].
build
(
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
))
# maybe FlashAttentionMetadata
if
req_offset
>
0
:
self
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
block_table
=
origin_block_table
self
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
slot_mapping
=
origin_slot_mapping
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
return
attn_metadata
def
pad_num_input_tokens
(
self
,
scheduler_output
):
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_scheduled_tokens
)
else
:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
vllm_config
.
compilation_config
.
pass_config
.
\
enable_sequence_parallelism
and
tp_size
>
1
:
from
vllm.utils
import
round_up
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
tp_size
)
else
:
num_input_tokens
=
num_scheduled_tokens
# Padding for DP
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_input_tokens
)
num_input_tokens
+=
num_pad
return
num_input_tokens
,
num_tokens_across_dp
@
torch
.
inference_mode
()
def
execute_model
(
self
,
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
if
not
has_kv_transfer_group
():
# Return empty ModelRunnerOutput if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
return
self
.
kv_connector_no_forward
(
scheduler_output
)
assert
not
self
.
use_cuda_graph
,
'v1 engine with tbo do not support cuda-graph'
# Prepare the decoder inputs.
attn_metadata
,
logits_indices
,
spec_decode_metadata
=
(
super
().
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_scheduled_tokens
)
else
:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
vllm_config
.
compilation_config
.
pass_config
.
\
enable_sequence_parallelism
and
tp_size
>
1
:
from
vllm.utils
import
round_up
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
tp_size
)
else
:
num_input_tokens
=
num_scheduled_tokens
# Padding for DP
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_input_tokens
)
num_input_tokens
+=
num_pad
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
self
.
_execute_mm_encoder
(
scheduler_output
)
mm_embeds
=
self
.
_gather_mm_embeddings
(
scheduler_output
)
else
:
mm_embeds
=
[]
if
self
.
is_multimodal_model
and
get_pp_group
().
is_first_rank
:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
if
mm_embeds
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
,
mm_embeds
)
else
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
# TODO(woosuk): Avoid the copy. Optimize.
self
.
inputs_embeds
[:
num_scheduled_tokens
].
copy_
(
inputs_embeds
)
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
input_ids
=
None
else
:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
if
self
.
uses_mrope
:
positions
=
self
.
mrope_positions
[:,
:
num_input_tokens
]
else
:
positions
=
self
.
positions
[:
num_input_tokens
]
if
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
None
else
:
intermediate_tensors
=
self
.
sync_and_slice_intermediate_tensors
(
num_input_tokens
,
intermediate_tensors
,
True
)
use_tbo
=
False
if
len
(
scheduler_output
.
num_scheduled_tokens
)
>
1
:
self
.
split_scheduler_output
(
scheduler_output
)
if
self
.
scheduler_output_left
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
and
\
self
.
scheduler_output_right
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
:
use_tbo
=
True
if
use_tbo
:
num_input_tokens_left
=
self
.
scheduler_output_left
.
total_num_scheduled_tokens
num_input_tokens_right
=
num_input_tokens
-
num_input_tokens_left
attn_metadata_left
=
self
.
prepare_tbo_atten_metadata
(
self
.
scheduler_output_left
,
self
.
req_ids_left
,
0
)
attn_metadata_right
=
self
.
prepare_tbo_atten_metadata
(
self
.
scheduler_output_right
,
self
.
req_ids_right
,
self
.
req_num_left
)
model_output
=
tbo_model_executable_v1
(
self
,
attn_metadata_left
,
attn_metadata_right
,
num_input_tokens_left
,
num_input_tokens_right
,
num_tokens_across_dp
,
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
finished_sending
,
finished_recving
=
None
,
None
else
:
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
):
self
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
self
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
self
.
get_finished_kv_transfers
(
scheduler_output
))
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
hidden_states
=
model_output
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output
=
\
self
.
parallel_config
.
distributed_executor_backend
\
==
"external_launcher"
and
len
(
get_pp_group
().
ranks
)
>
0
if
not
get_pp_group
().
is_last_rank
:
# For mid-pipeline stages, return the hidden states.
if
not
broadcast_pp_output
:
return
hidden_states
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
get_pp_group
().
send_tensor_dict
(
hidden_states
.
tensors
,
all_gather_group
=
get_tp_group
())
logits
=
None
else
:
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
if
broadcast_pp_output
:
model_output_broadcast_data
=
{
"logits"
:
logits
.
contiguous
(),
}
if
logits
is
not
None
else
{}
model_output_broadcast_data
=
get_pp_group
().
broadcast_tensor_dict
(
model_output_broadcast_data
,
src
=
len
(
get_pp_group
().
ranks
)
-
1
)
assert
model_output_broadcast_data
is
not
None
logits
=
model_output_broadcast_data
[
"logits"
]
# Apply structured output bitmasks if present
if
scheduler_output
.
grammar_bitmask
is
not
None
:
self
.
apply_grammar_bitmask
(
scheduler_output
,
logits
)
# Sample the next token and get logprobs if needed.
sampling_metadata
=
self
.
input_batch
.
sampling_metadata
if
spec_decode_metadata
is
None
:
sampler_output
=
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
else
:
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert
logits
is
not
None
bonus_logits
=
logits
[
spec_decode_metadata
.
bonus_logits_indices
]
sampler_output
=
self
.
sampler
(
logits
=
bonus_logits
,
sampling_metadata
=
sampling_metadata
,
)
bonus_token_ids
=
sampler_output
.
sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits
=
logits
[
spec_decode_metadata
.
target_logits_indices
]
output_token_ids
=
self
.
rejection_sampler
(
spec_decode_metadata
,
None
,
# draft_probs
target_logits
,
bonus_token_ids
,
sampling_metadata
,
)
sampler_output
.
sampled_token_ids
=
output_token_ids
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
discard_sampled_tokens_req_indices
=
[]
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
if
seq_len
<
req_state
.
num_tokens
:
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
generator
=
self
.
input_batch
.
generators
.
get
(
i
)
if
generator
is
not
None
:
generator
.
set_offset
(
generator
.
get_offset
()
-
4
)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices
.
append
(
i
)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors
=
sampler_output
.
logprobs_tensors
logprobs_lists
=
logprobs_tensors
.
tolists
()
\
if
logprobs_tensors
is
not
None
else
None
# Compute prompt logprobs if needed.
prompt_logprobs_dict
=
self
.
_get_prompt_logprobs_dict
(
hidden_states
[:
num_scheduled_tokens
],
scheduler_output
,
)
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
max_gen_len
==
1
:
# No spec decode tokens.
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
else
:
# Includes spec decode tokens.
valid_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
sampled_token_ids
,
self
.
input_batch
.
vocab_size
,
)
# Mask out the sampled tokens that should not be sampled.
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
elif
self
.
speculative_config
.
method
==
"ngram"
:
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
spec_token_ids
=
self
.
generate_draft_token_ids
(
valid_sampled_token_ids
,
sampling_metadata
)
elif
self
.
speculative_config
.
method
==
"medusa"
:
assert
isinstance
(
self
.
drafter
,
MedusaProposer
)
if
max_gen_len
==
1
:
hidden_states
=
sample_hidden_states
else
:
indices
=
[]
offset
=
0
for
num_draft
,
tokens
in
zip
(
spec_decode_metadata
.
num_draft_tokens
,
valid_sampled_token_ids
):
indices
.
append
(
offset
+
len
(
tokens
)
-
1
)
offset
+=
num_draft
+
1
indices
=
torch
.
tensor
(
indices
,
device
=
sample_hidden_states
.
device
)
hidden_states
=
sample_hidden_states
[
indices
]
spec_token_ids
=
self
.
drafter
.
propose
(
target_hidden_states
=
hidden_states
,
sampling_metadata
=
sampling_metadata
,
)
elif
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# TODO(woosuk): Refactor the loop.
next_token_ids
:
list
[
int
]
=
[]
for
i
,
token_ids
in
enumerate
(
valid_sampled_token_ids
):
if
token_ids
:
# Common case.
next_token_id
=
token_ids
[
-
1
]
else
:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id
=
self
.
input_batch
.
req_ids
[
i
]
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
next_token_id
=
req_state
.
get_token_id
(
seq_len
)
next_token_ids
.
append
(
next_token_id
)
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
eagle_attn_metadata
=
attn_metadata
[
self
.
drafter
.
attn_layer_names
[
0
]]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if
hasattr
(
eagle_attn_metadata
,
"block_table"
):
block_table
=
eagle_attn_metadata
.
block_table
else
:
block_table
=
None
num_rejected_tokens_tuple
=
None
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_positions
=
positions
[:
num_scheduled_tokens
]
if
self
.
use_aux_hidden_state_outputs
:
target_hidden_states
=
torch
.
cat
(
[
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
else
:
target_hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
cu_num_tokens
=
eagle_attn_metadata
.
query_start_loc
else
:
# TODO(woosuk): Refactor this.
num_draft_tokens
=
spec_decode_metadata
.
num_draft_tokens
num_rejected_tokens
=
[
n
+
1
-
len
(
valid_sampled_token_ids
[
i
])
if
n
>
0
else
0
for
i
,
n
in
enumerate
(
num_draft_tokens
)
]
num_rejected_tokens_tensor
=
async_tensor_h2d
(
num_rejected_tokens
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
)
num_tokens
=
num_scheduled_tokens
-
sum
(
num_rejected_tokens
)
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
eagle_attn_metadata
.
query_start_loc
,
num_rejected_tokens_tensor
,
num_tokens
,
)
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_positions
=
positions
[
token_indices
]
if
self
.
use_aux_hidden_state_outputs
:
target_hidden_states
=
torch
.
cat
(
[
h
[
token_indices
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
else
:
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
num_rejected_tokens_tuple
=
(
num_rejected_tokens
,
num_rejected_tokens_tensor
)
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_slot_mapping
=
target_slot_mapping
,
next_token_ids
=
next_token_ids
,
cu_num_tokens
=
cu_num_tokens
,
block_table
=
block_table
,
sampling_metadata
=
sampling_metadata
,
num_rejected_tokens_tuple
=
num_rejected_tokens_tuple
)
spec_token_ids
=
draft_token_ids
.
tolist
()
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
return
ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
valid_sampled_token_ids
,
spec_token_ids
=
spec_token_ids
,
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
)
\ No newline at end of file
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
0 → 100644
View file @
20e75ed6
import
os
import
queue
import
threading
import
torch
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.forward_context
import
init_tbo_forward_context
from
vllm.logger
import
init_logger
from
vllm.profiler.prof
import
profile
from
vllm
import
envs
logger
=
init_logger
(
__name__
)
tbo_step_stream
=
None
all_reduce_stream
=
None
class
TwoBatchOverlap
():
def
__init__
(
self
):
global
tbo_step_stream
global
all_reduce_stream
self
.
model_input_left_queue
=
queue
.
Queue
()
self
.
model_input_right_queue
=
queue
.
Queue
()
self
.
states_left_queue
=
queue
.
Queue
()
self
.
states_right_queue
=
queue
.
Queue
()
self
.
left_thread
=
None
self
.
right_thread
=
None
self
.
left_tid
=
0
self
.
right_tid
=
0
self
.
sem_left
=
threading
.
Semaphore
(
0
)
self
.
sem_right
=
threading
.
Semaphore
(
0
)
self
.
left_first
=
False
self
.
tbo_running
=
False
self
.
tbo_in_capture
=
False
if
tbo_step_stream
==
None
:
tbo_step_stream
=
torch
.
cuda
.
Stream
()
all_reduce_stream
=
torch
.
cuda
.
Stream
()
self
.
step_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_left_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_right_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_left_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_right_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
def
init_tbo_thread
(
self
):
self
.
model_input_left_queue
.
empty
()
self
.
model_input_right_queue
.
empty
()
self
.
left_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_left_queue
,))
self
.
left_thread
.
start
()
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
.
start
()
logger
.
info
(
'tbo:two batch overlap start'
)
def
finish_thread
(
self
):
self
.
left_thread
.
join
()
self
.
left_thread
=
None
self
.
right_thread
.
join
()
self
.
right_thread
=
None
@
torch
.
inference_mode
()
def
thread_two_batch_overlap
(
self
,
queue
):
is_left_thread
=
False
tid
=
threading
.
get_ident
()
if
queue
==
self
.
model_input_left_queue
:
self
.
left_tid
=
tid
is_left_thread
=
True
init_tbo_forward_context
(
True
,
self
.
left_tid
)
else
:
self
.
right_tid
=
tid
init_tbo_forward_context
(
False
,
self
.
right_tid
)
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
queue
.
get
()
profile
.
ProfRangePush
(
'start'
)
self
.
tbo_thread_synchronize
(
tid
)
if
is_left_thread
:
attn_metadata
=
self
.
attn_metadata_left
num_input_tokens
=
self
.
num_input_tokens_left
input_ids
=
self
.
input_ids_left
positions
=
self
.
positions_left
else
:
attn_metadata
=
self
.
attn_metadata_right
num_input_tokens
=
self
.
num_input_tokens_right
input_ids
=
self
.
input_ids_right
positions
=
self
.
positions_right
model_output
=
None
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with
set_forward_context
(
attn_metadata
,
self
.
model_runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
self
.
num_tokens_across_dp
):
model_output
=
self
.
model_runner
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
intermediate_tensors
=
self
.
intermediate_tensors
,
inputs_embeds
=
self
.
inputs_embeds
,
)
if
is_left_thread
:
self
.
sem_right
.
release
()
self
.
states_left_queue
.
put
(
model_output
)
else
:
self
.
states_right_queue
.
put
(
model_output
)
profile
.
ProfRangePop
()
def
tbo_thread_synchronize
(
self
,
tid
):
if
tid
==
self
.
left_tid
:
if
not
self
.
left_first
:
self
.
sem_right
.
release
()
self
.
left_first
=
False
profile
.
ProfRangePop
()
self
.
sem_left
.
acquire
()
profile
.
ProfRangePush
(
'left'
)
return
self
.
event_left_c2t
,
self
.
event_left_t2c
else
:
self
.
sem_left
.
release
()
profile
.
ProfRangePop
()
self
.
sem_right
.
acquire
()
profile
.
ProfRangePush
(
'right'
)
return
self
.
event_right_c2t
,
self
.
event_right_t2c
def
set_model_input
(
self
,
model_runner
,
attn_metadata_left
,
attn_metadata_right
,
num_input_tokens_left
,
num_input_tokens_right
,
input_ids_left
,
input_ids_right
,
positions_left
,
positions_right
,
num_tokens_across_dp
,
intermediate_tensors
,
inputs_embeds
):
self
.
model_runner
=
model_runner
self
.
attn_metadata_left
=
attn_metadata_left
self
.
attn_metadata_right
=
attn_metadata_right
self
.
num_input_tokens_left
=
num_input_tokens_left
self
.
num_input_tokens_right
=
num_input_tokens_right
self
.
input_ids_left
=
input_ids_left
self
.
input_ids_right
=
input_ids_right
self
.
positions_left
=
positions_left
self
.
positions_right
=
positions_right
self
.
num_tokens_across_dp
=
num_tokens_across_dp
self
.
intermediate_tensors
=
intermediate_tensors
self
.
inputs_embeds
=
inputs_embeds
self
.
model_input_left_queue
.
put
(
None
)
self
.
model_input_right_queue
.
put
(
None
)
def
get_model_output
(
self
):
states_left
=
self
.
states_left_queue
.
get
()
states_right
=
self
.
states_right_queue
.
get
()
return
states_left
,
states_right
tbo_obj_v1
=
None
def
is_enable_tbo_v1
():
global
tbo_obj_v1
return
tbo_obj_v1
!=
None
def
init_two_batch_overlap
():
global
tbo_obj_v1
if
tbo_obj_v1
==
None
:
tbo_obj_v1
=
TwoBatchOverlap
()
tbo_obj_v1
.
init_tbo_thread
()
def
tbo_all_reduce_v1
(
obj
):
if
envs
.
VLLM_ENABLE_TBO
and
tbo_obj_v1
!=
None
and
tbo_obj_v1
.
tbo_running
:
tid
=
threading
.
get_ident
()
if
tid
==
tbo_obj_v1
.
left_tid
:
event_c2t
,
event_t2c
=
tbo_obj_v1
.
event_left_c2t
,
tbo_obj_v1
.
event_left_t2c
else
:
event_c2t
,
event_t2c
=
tbo_obj_v1
.
event_right_c2t
,
tbo_obj_v1
.
event_right_t2c
event_c2t
.
record
()
with
torch
.
cuda
.
stream
(
all_reduce_stream
):
all_reduce_stream
.
wait_event
(
event_c2t
)
output
=
tensor_model_parallel_all_reduce
(
obj
)
event_t2c
.
record
()
tbo_obj_v1
.
tbo_thread_synchronize
(
tid
)
tbo_step_stream
.
wait_event
(
event_t2c
)
return
output
return
tensor_model_parallel_all_reduce
(
obj
)
def
merge_model_output
(
states_left
,
states_right
):
if
isinstance
(
states_left
,
IntermediateTensors
):
output_map
=
{}
for
key
in
states_left
.
tensors
:
output_map
[
key
]
=
torch
.
concat
([
states_left
.
tensors
[
key
],
states_right
.
tensors
[
key
]],
dim
=
0
)
output
=
IntermediateTensors
(
output_map
)
else
:
output
=
torch
.
concat
([
states_left
,
states_right
],
dim
=
0
)
return
output
def
tbo_model_executable_v1
(
model_runner
,
attn_metadata_left
,
attn_metadata_right
,
num_input_tokens_left
,
num_input_tokens_right
,
num_tokens_across_dp
,
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
):
init_two_batch_overlap
()
tbo_obj_v1
.
tbo_running
=
True
tbo_obj_v1
.
left_first
=
True
tbo_obj_v1
.
step_event
.
record
()
current_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
tbo_step_stream
.
wait_event
(
tbo_obj_v1
.
step_event
)
tokens_split
=
[
num_input_tokens_left
,
num_input_tokens_right
]
input_ids_left
,
input_ids_right
=
torch
.
split
(
input_ids
,
tokens_split
,
dim
=
0
)
positions_left
,
positions_right
=
torch
.
split
(
positions
,
tokens_split
,
dim
=
0
)
tbo_obj_v1
.
set_model_input
(
model_runner
,
attn_metadata_left
,
attn_metadata_right
,
num_input_tokens_left
,
num_input_tokens_right
,
input_ids_left
,
input_ids_right
,
positions_left
,
positions_right
,
num_tokens_across_dp
,
intermediate_tensors
,
inputs_embeds
)
model_output_left
,
model_output_right
=
tbo_obj_v1
.
get_model_output
()
hidden_or_intermediate_states
=
merge_model_output
(
model_output_left
,
model_output_right
)
tbo_obj_v1
.
tbo_running
=
False
tbo_obj_v1
.
step_event
.
record
()
tbo_obj_v1
.
finish_thread
()
current_stream
.
wait_event
(
tbo_obj_v1
.
step_event
)
return
hidden_or_intermediate_states
\ No newline at end of file
vllm/v1/worker/gpu_worker.py
View file @
20e75ed6
...
...
@@ -22,6 +22,7 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.v1.gpu_model_runner
import
TBO_GPUModelRunner
from
vllm.utils
import
GiB_bytes
,
MemorySnapshot
,
memory_profiling
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
ModelRunnerOutput
...
...
@@ -162,8 +163,12 @@ class Worker(WorkerBase):
set_random_seed
(
self
.
model_config
.
seed
)
# Construct the model runner
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
if
envs
.
VLLM_ENABLE_TBO
:
self
.
model_runner
:
TBO_GPUModelRunner
=
TBO_GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
else
:
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
if
self
.
rank
==
0
:
# If usage stat is enabled, collect relevant info.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment