Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
88443051
Commit
88443051
authored
Oct 11, 2025
by
zhuwenwen
Browse files
remove two_batch_overlap
parent
1bd3ae33
Changes
14
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
10 additions
and
1553 deletions
+10
-1553
vllm/attention/layer.py
vllm/attention/layer.py
+3
-10
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+0
-3
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+0
-2
vllm/forward_context.py
vllm/forward_context.py
+0
-15
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+1
-7
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+1
-6
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+3
-9
vllm/model_executor/models/deepseek_v3.py
vllm/model_executor/models/deepseek_v3.py
+2
-7
vllm/two_batch_overlap/forward_context.py
vllm/two_batch_overlap/forward_context.py
+0
-35
vllm/two_batch_overlap/model_input_split.py
vllm/two_batch_overlap/model_input_split.py
+0
-399
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+0
-481
vllm/two_batch_overlap/v1/model_input_split_v1.py
vllm/two_batch_overlap/v1/model_input_split_v1.py
+0
-335
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
+0
-243
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+0
-1
No files found.
vllm/attention/layer.py
View file @
88443051
...
@@ -7,7 +7,6 @@ import torch
...
@@ -7,7 +7,6 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_maybe_save_kv_layer_to_connector
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionType
from
vllm.attention
import
AttentionType
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
...
@@ -575,9 +574,6 @@ def unified_attention(
...
@@ -575,9 +574,6 @@ def unified_attention(
output
=
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
output
=
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
attn_metadata
)
attn_metadata
)
if
envs
.
VLLM_ENABLE_TBO
:
tbo_maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
else
:
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
return
output
return
output
...
@@ -625,9 +621,6 @@ def unified_attention_with_output(
...
@@ -625,9 +621,6 @@ def unified_attention_with_output(
output_scale
=
output_scale
,
output_scale
=
output_scale
,
output_block_scale
=
output_block_scale
)
output_block_scale
=
output_block_scale
)
if
envs
.
VLLM_ENABLE_TBO
:
tbo_maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
else
:
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
...
...
vllm/compilation/decorators.py
View file @
88443051
...
@@ -223,9 +223,6 @@ def _support_torch_compile(
...
@@ -223,9 +223,6 @@ def _support_torch_compile(
# torch.compiler.is_compiling() means we are inside the compilation
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
# need to compile the model inside.
if
envs
.
VLLM_ENABLE_TBO
and
get_forward_context
().
skip_cuda_graphs
:
return
self
.
forward
(
*
args
,
**
kwargs
)
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
()
or
get_profilling
():
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
()
or
get_profilling
():
return
self
.
forward
(
*
args
,
**
kwargs
)
return
self
.
forward
(
*
args
,
**
kwargs
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
88443051
...
@@ -273,8 +273,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -273,8 +273,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
torch.Tensor: A tensor containing the extracted KV slices.
torch.Tensor: A tensor containing the extracted KV slices.
Returns None if the layout is unsupported.
Returns None if the layout is unsupported.
"""
"""
if
envs
.
VLLM_ENABLE_TBO
:
slot_mapping
=
slot_mapping
.
pin_memory
().
to
(
device
=
layer
.
device
,
non_blocking
=
True
)
if
(
isinstance
(
attn_metadata
,
MLACommonMetadata
)
if
(
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
layer
.
shape
[
1
]
==
2
):
# MLA or FlashInfer
or
layer
.
shape
[
1
]
==
2
):
# MLA or FlashInfer
return
layer
[
block_ids
,
...]
return
layer
[
block_ids
,
...]
...
...
vllm/forward_context.py
View file @
88443051
...
@@ -17,7 +17,6 @@ from vllm.logger import init_logger
...
@@ -17,7 +17,6 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.worker.ubatch_utils
import
UBatchSlices
,
is_second_ubatch_empty
from
vllm.v1.worker.ubatch_utils
import
UBatchSlices
,
is_second_ubatch_empty
from
vllm.two_batch_overlap.forward_context
import
get_tbo_forward_context
,
set_tbo_forward_context
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
@@ -293,14 +292,6 @@ _forward_context: Optional[ForwardContext] = None
...
@@ -293,14 +292,6 @@ _forward_context: Optional[ForwardContext] = None
def
get_forward_context
()
->
ForwardContext
:
def
get_forward_context
()
->
ForwardContext
:
if
envs
.
VLLM_ENABLE_TBO
:
forward_context
=
get_tbo_forward_context
()
"""Get the current forward context."""
assert
forward_context
is
not
None
,
(
"Forward context is not set. "
"Please use `set_forward_context` to set the forward context."
)
return
forward_context
"""Get the current forward context."""
"""Get the current forward context."""
assert
_forward_context
is
not
None
,
(
assert
_forward_context
is
not
None
,
(
"Forward context is not set. "
"Forward context is not set. "
...
@@ -371,8 +362,6 @@ def set_forward_context(
...
@@ -371,8 +362,6 @@ def set_forward_context(
virtual_engine
,
dp_metadata
,
virtual_engine
,
dp_metadata
,
cudagraph_runtime_mode
,
cudagraph_runtime_mode
,
batch_descriptor
,
ubatch_slices
)
batch_descriptor
,
ubatch_slices
)
if
envs
.
VLLM_ENABLE_TBO
:
set_tbo_forward_context
(
forward_context
)
try
:
try
:
with
override_forward_context
(
forward_context
):
with
override_forward_context
(
forward_context
):
...
@@ -414,10 +403,6 @@ def set_forward_context(
...
@@ -414,10 +403,6 @@ def set_forward_context(
"(batchsize, count, median_time(ms)): %s"
),
"(batchsize, count, median_time(ms)): %s"
),
forward_stats
)
forward_stats
)
if
envs
.
VLLM_ENABLE_TBO
:
set_tbo_forward_context
(
_forward_context
)
_profiling
:
bool
=
False
_profiling
:
bool
=
False
@
contextmanager
@
contextmanager
...
...
vllm/model_executor/layers/linear.py
View file @
88443051
...
@@ -1478,9 +1478,6 @@ class RowParallelLinear(LinearBase):
...
@@ -1478,9 +1478,6 @@ class RowParallelLinear(LinearBase):
else
:
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
register_parameter
(
"bias"
,
None
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
update_param_tp_status
()
self
.
update_param_tp_status
()
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
...
@@ -1568,9 +1565,6 @@ class RowParallelLinear(LinearBase):
...
@@ -1568,9 +1565,6 @@ class RowParallelLinear(LinearBase):
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias_
)
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias_
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
else
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
else
:
output
=
output_parallel
output
=
output_parallel
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
88443051
...
@@ -328,8 +328,6 @@ class VocabParallelEmbedding(CustomOp):
...
@@ -328,8 +328,6 @@ class VocabParallelEmbedding(CustomOp):
self
.
num_embeddings_padded
,
self
.
num_embeddings_padded
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
weight_loader
=
self
.
weight_loader
)
weight_loader
=
self
.
weight_loader
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
@
classmethod
@
classmethod
def
_get_indices
(
cls
,
vocab_size_padded
:
int
,
org_vocab_size_padded
:
int
,
def
_get_indices
(
cls
,
vocab_size_padded
:
int
,
org_vocab_size_padded
:
int
,
...
@@ -473,9 +471,6 @@ class VocabParallelEmbedding(CustomOp):
...
@@ -473,9 +471,6 @@ class VocabParallelEmbedding(CustomOp):
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
output_parallel
.
masked_fill_
(
input_mask
.
unsqueeze
(
-
1
),
0
)
output_parallel
.
masked_fill_
(
input_mask
.
unsqueeze
(
-
1
),
0
)
# Reduce across all the model parallel GPUs.
# Reduce across all the model parallel GPUs.
if
envs
.
VLLM_ENABLE_TBO
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
else
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
return
output
return
output
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
88443051
...
@@ -260,9 +260,6 @@ class DeepseekV2MoE(nn.Module):
...
@@ -260,9 +260,6 @@ class DeepseekV2MoE(nn.Module):
num_redundant_experts
=
self
.
n_redundant_experts
,
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
,
is_sequence_parallel
=
self
.
is_sequence_parallel
,
)
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -306,9 +303,6 @@ class DeepseekV2MoE(nn.Module):
...
@@ -306,9 +303,6 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
,
0
)
final_hidden_states
,
0
)
final_hidden_states
=
final_hidden_states
[:
num_tokens
]
final_hidden_states
=
final_hidden_states
[:
num_tokens
]
elif
self
.
tp_size
>
1
:
elif
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
final_hidden_states
=
self
.
tbo_all_reduce
(
final_hidden_states
)
else
:
final_hidden_states
=
(
final_hidden_states
=
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
final_hidden_states
))
final_hidden_states
))
...
...
vllm/model_executor/models/deepseek_v3.py
View file @
88443051
...
@@ -149,8 +149,6 @@ class DeepseekV3MoE(nn.Module):
...
@@ -149,8 +149,6 @@ class DeepseekV3MoE(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
False
,
reduce_results
=
False
,
)
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
...
@@ -165,9 +163,6 @@ class DeepseekV3MoE(nn.Module):
...
@@ -165,9 +163,6 @@ class DeepseekV3MoE(nn.Module):
if
shared_output
is
not
None
:
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
final_hidden_states
=
self
.
tbo_all_reduce
(
final_hidden_states
)
else
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
final_hidden_states
)
...
...
vllm/two_batch_overlap/forward_context.py
deleted
100644 → 0
View file @
1bd3ae33
import
threading
_forward_context_left
=
None
_forward_context_right
=
None
_left_tid
=
0
_right_tid
=
0
def
init_tbo_forward_context
(
left_flag
,
tid
):
global
_left_tid
global
_right_tid
if
left_flag
:
_left_tid
=
tid
else
:
_right_tid
=
tid
def
set_tbo_forward_context
(
_forward_context
):
global
_forward_context_left
global
_forward_context_right
tid
=
threading
.
get_ident
()
if
tid
==
_left_tid
:
_forward_context_left
=
_forward_context
else
:
_forward_context_right
=
_forward_context
def
get_tbo_forward_context
():
tid
=
threading
.
get_ident
()
if
tid
==
_left_tid
:
return
_forward_context_left
else
:
return
_forward_context_right
vllm/two_batch_overlap/model_input_split.py
deleted
100644 → 0
View file @
1bd3ae33
This diff is collapsed.
Click to expand it.
vllm/two_batch_overlap/two_batch_overlap.py
deleted
100644 → 0
View file @
1bd3ae33
This diff is collapsed.
Click to expand it.
vllm/two_batch_overlap/v1/model_input_split_v1.py
deleted
100644 → 0
View file @
1bd3ae33
from
typing
import
Any
,
Optional
,
Union
import
numpy
as
np
import
torch
from
vllm
import
envs
from
vllm.distributed.kv_transfer.kv_transfer_state
import
get_kv_transfer_group
,
has_kv_transfer_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_model_executable_v1
from
vllm.utils
import
async_tensor_h2d
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadataBuilder
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.worker.block_table
import
BlockTable
class
TBOModelInputSplit
():
def
__init__
(
self
):
self
.
req_ids_left
=
[]
self
.
req_ids_right
=
[]
self
.
req_num_left
=
0
self
.
req_num_right
=
0
self
.
scheduler_output_left
=
None
self
.
scheduler_output_right
=
None
self
.
query_start_loc_right
=
None
input_split
=
TBOModelInputSplit
()
def
split_scheduler_output
(
runner
,
scheduler_output
:
SchedulerOutput
):
split_tokens
=
scheduler_output
.
total_num_scheduled_tokens
//
2
req_ids
=
runner
.
input_batch
.
req_ids
tokens_counter
=
0
min_idx
=
-
1
min_counter
=
0
for
i
,
id
in
enumerate
(
req_ids
):
tokens_counter
+=
scheduler_output
.
num_scheduled_tokens
[
id
]
diff
=
abs
(
tokens_counter
-
split_tokens
)
if
min_idx
==
-
1
or
diff
<
min_counter
:
min_idx
=
i
min_counter
=
diff
if
tokens_counter
>
split_tokens
or
diff
==
0
:
break
input_split
.
req_num_left
=
min_idx
+
1
if
input_split
.
req_num_left
==
len
(
req_ids
):
input_split
.
req_num_left
=
input_split
.
req_num_left
-
1
input_split
.
req_ids_left
=
req_ids
[:
input_split
.
req_num_left
]
input_split
.
req_ids_right
=
req_ids
[
input_split
.
req_num_left
:]
input_split
.
req_num_right
=
len
(
req_ids
)
-
input_split
.
req_num_left
new_req_data_left
=
[]
new_req_data_right
=
[]
cached_reqs_left
=
[]
cached_reqs_right
=
[]
num_scheduled_tokens_left
=
{}
num_scheduled_tokens_right
=
{}
total_num_scheduled_tokens_left
=
0
total_num_scheduled_tokens_right
=
0
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
if
new_req
.
req_id
in
input_split
.
req_ids_left
:
new_req_data_left
.
append
(
new_req
)
else
:
new_req_data_right
.
append
(
new_req
)
cached_reqs_left
=
CachedRequestData
.
make_empty
()
cached_reqs_right
=
CachedRequestData
.
make_empty
()
for
req_idx
,
req_id
in
enumerate
(
scheduler_output
.
scheduled_cached_reqs
.
req_ids
):
if
req_id
in
input_split
.
req_ids_left
:
cached_reqs_left
.
req_ids
.
append
(
req_id
)
cached_reqs_left
.
resumed_from_preemption
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
resumed_from_preemption
[
req_idx
])
if
len
(
scheduler_output
.
scheduled_cached_reqs
.
new_token_ids
)
>
0
:
cached_reqs_left
.
new_token_ids
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
new_token_ids
[
req_idx
])
cached_reqs_left
.
new_block_ids
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
new_block_ids
[
req_idx
])
cached_reqs_left
.
num_computed_tokens
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
num_computed_tokens
[
req_idx
])
else
:
cached_reqs_right
.
req_ids
.
append
(
req_id
)
cached_reqs_right
.
resumed_from_preemption
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
resumed_from_preemption
[
req_idx
])
if
len
(
scheduler_output
.
scheduled_cached_reqs
.
new_token_ids
)
>
0
:
cached_reqs_right
.
new_token_ids
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
new_token_ids
[
req_idx
])
cached_reqs_right
.
new_block_ids
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
new_block_ids
[
req_idx
])
cached_reqs_right
.
num_computed_tokens
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
num_computed_tokens
[
req_idx
])
for
key
,
value
in
scheduler_output
.
num_scheduled_tokens
.
items
():
if
key
in
input_split
.
req_ids_left
:
num_scheduled_tokens_left
[
key
]
=
value
total_num_scheduled_tokens_left
+=
value
else
:
num_scheduled_tokens_right
[
key
]
=
value
total_num_scheduled_tokens_right
+=
value
input_split
.
scheduler_output_left
=
SchedulerOutput
(
scheduled_new_reqs
=
new_req_data_left
,
scheduled_cached_reqs
=
cached_reqs_left
,
num_scheduled_tokens
=
num_scheduled_tokens_left
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens_left
,
scheduled_spec_decode_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
,
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
,
##unsupport yet
num_common_prefix_blocks
=
scheduler_output
.
num_common_prefix_blocks
,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids
=
scheduler_output
.
finished_req_ids
,
free_encoder_input_ids
=
scheduler_output
.
free_encoder_input_ids
,
structured_output_request_ids
=
scheduler_output
.
structured_output_request_ids
,
grammar_bitmask
=
scheduler_output
.
grammar_bitmask
,
)
input_split
.
scheduler_output_right
=
SchedulerOutput
(
scheduled_new_reqs
=
new_req_data_right
,
scheduled_cached_reqs
=
cached_reqs_right
,
num_scheduled_tokens
=
num_scheduled_tokens_right
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens_right
,
scheduled_spec_decode_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
,
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
,
##unsupport yet
num_common_prefix_blocks
=
scheduler_output
.
num_common_prefix_blocks
,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids
=
scheduler_output
.
finished_req_ids
,
free_encoder_input_ids
=
scheduler_output
.
free_encoder_input_ids
,
structured_output_request_ids
=
scheduler_output
.
structured_output_request_ids
,
grammar_bitmask
=
scheduler_output
.
grammar_bitmask
,
)
def
prepare_tbo_atten_metadata
(
runner
,
scheduler_output
:
"SchedulerOutput"
,
req_ids
,
req_offset
)
->
tuple
[
dict
[
str
,
Any
],
torch
.
Tensor
,
Optional
[
SpecDecodeMetadata
]]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
num_reqs
=
len
(
req_ids
)
assert
num_reqs
>
0
seq_len_offset
=
req_offset
# Get the number of scheduled tokens for each request.
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
max
(
tokens
)
if
req_offset
>
0
:
#right
if
input_split
.
query_start_loc_right
==
None
:
# TODO: create when system init
input_split
.
query_start_loc_right
=
torch
.
zeros
(
runner
.
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
runner
.
device
)
cu_num_tokens
,
arange
=
runner
.
_get_cumsum_and_arange
(
num_scheduled_tokens
)
# Prepare the attention metadata.
runner
.
query_start_loc_np
[
0
]
=
0
runner
.
query_start_loc_np
[
1
:
num_reqs
+
1
]
=
cu_num_tokens
input_split
.
query_start_loc_right
[
0
:
num_reqs
+
1
].
copy_
(
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
],
non_blocking
=
True
)
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
input_split
.
query_start_loc_right
[
num_reqs
+
1
:].
fill_
(
runner
.
query_start_loc_cpu
[
num_reqs
].
item
())
query_start_loc
=
input_split
.
query_start_loc_right
[:
num_reqs
+
1
]
else
:
query_start_loc
=
runner
.
query_start_loc
[:
num_reqs
+
1
]
seq_lens
=
runner
.
seq_lens
[
seq_len_offset
:
seq_len_offset
+
num_reqs
]
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
)
attn_metadata
:
dict
[
str
,
Any
]
=
{}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
runner
.
kv_cache_config
.
kv_cache_groups
):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
metadata_builder
=
runner
.
attn_metadata_builders
[
kv_cache_group_id
]
if
runner
.
cascade_attn_enabled
:
common_prefix_len
=
runner
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
metadata_builder
,
)
if
req_offset
>
0
:
origin_block_table
=
metadata_builder
.
block_table
.
block_table
metadata_builder
.
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
origin_slot_mapping
=
metadata_builder
.
block_table
.
slot_mapping
metadata_builder
.
block_table
.
slot_mapping
=
\
origin_slot_mapping
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
origin_slot_map_cpu
=
metadata_builder
.
block_table
.
slot_mapping_cpu
metadata_builder
.
block_table
.
slot_mapping_cpu
=
\
origin_slot_map_cpu
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
_num_decodes_record
=
metadata_builder
.
_num_decodes
_num_prefills_record
=
metadata_builder
.
_num_prefills
_num_decode_tokens_record
=
metadata_builder
.
_num_decode_tokens
_num_prefill_tokens_record
=
metadata_builder
.
_num_prefill_tokens
metadata_builder
.
_num_decodes
=
0
metadata_builder
.
_num_prefills
=
num_reqs
metadata_builder
.
_num_decode_tokens
=
0
metadata_builder
.
_num_prefill_tokens
=
total_num_scheduled_tokens
attn_metadata_i
=
(
metadata_builder
.
build
(
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
))
# maybe FlashAttentionMetadata
if
req_offset
>
0
:
metadata_builder
.
block_table
.
block_table
=
origin_block_table
metadata_builder
.
block_table
.
slot_mapping
=
origin_slot_mapping
metadata_builder
.
block_table
.
slot_mapping_cpu
=
origin_slot_map_cpu
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
metadata_builder
.
_num_decodes
=
_num_decodes_record
metadata_builder
.
_num_prefills
=
_num_prefills_record
metadata_builder
.
_num_decode_tokens
=
_num_decode_tokens_record
metadata_builder
.
_num_prefill_tokens
=
_num_prefill_tokens_record
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
return
attn_metadata
def
pad_num_input_tokens
(
self
,
scheduler_output
):
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_scheduled_tokens
)
else
:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
vllm_config
.
compilation_config
.
pass_config
.
\
enable_sequence_parallelism
and
tp_size
>
1
:
from
vllm.utils
import
round_up
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
tp_size
)
else
:
num_input_tokens
=
num_scheduled_tokens
# Padding for DP
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_input_tokens
)
num_input_tokens
+=
num_pad
return
num_input_tokens
,
num_tokens_across_dp
def
tbo_split_and_execute_model
(
runner
,
attn_metadata
,
num_input_tokens
,
num_tokens_across_dp
,
input_ids
,
positions
,
inputs_embeds
,
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
skip_cuda_graphs
:
bool
=
True
,
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
use_tbo
=
False
if
isinstance
(
runner
.
attn_metadata_builders
[
0
],
MLACommonMetadataBuilder
)
and
\
runner
.
attn_metadata_builders
[
0
].
_num_decodes
>
0
:
#is mla decode
use_tbo
=
False
else
:
if
len
(
scheduler_output
.
num_scheduled_tokens
)
>
1
and
num_input_tokens
>
envs
.
VLLM_TBO_MIN_TOKENS
:
split_scheduler_output
(
runner
,
scheduler_output
)
use_tbo
=
True
if
use_tbo
:
num_input_tokens_left
=
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
num_input_tokens_right
=
num_input_tokens
-
num_input_tokens_left
attn_metadata_left
=
prepare_tbo_atten_metadata
(
runner
,
input_split
.
scheduler_output_left
,
input_split
.
req_ids_left
,
0
)
attn_metadata_right
=
prepare_tbo_atten_metadata
(
runner
,
input_split
.
scheduler_output_right
,
input_split
.
req_ids_right
,
input_split
.
req_num_left
)
with
set_forward_context
(
attn_metadata
,
runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
tbo_model_executable_v1
(
runner
,
attn_metadata_left
,
attn_metadata_right
,
num_input_tokens_left
,
num_input_tokens_right
,
num_tokens_across_dp
,
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
runner
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
runner
.
get_finished_kv_transfers
(
scheduler_output
))
#finished_sending, finished_recving = None, None
else
:
# Run the decoder.
# Use persistent buffers for CUDA graphs.
envs
.
VLLM_ENABLE_TBO
=
False
with
set_forward_context
(
attn_metadata
,
runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
skip_cuda_graphs
=
skip_cuda_graphs
):
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
runner
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
runner
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
runner
.
get_finished_kv_transfers
(
scheduler_output
))
envs
.
VLLM_ENABLE_TBO
=
True
return
model_output
,
finished_sending
,
finished_recving
\ No newline at end of file
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
deleted
100644 → 0
View file @
1bd3ae33
import
os
import
queue
import
threading
import
torch
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.forward_context
import
init_tbo_forward_context
from
vllm.logger
import
init_logger
from
vllm.profiler.prof
import
profile
from
vllm
import
envs
logger
=
init_logger
(
__name__
)
tbo_step_stream
=
None
all_reduce_stream
=
None
class
TwoBatchOverlap
():
def
__init__
(
self
):
global
tbo_step_stream
global
all_reduce_stream
self
.
model_input_left_queue
=
queue
.
Queue
()
self
.
model_input_right_queue
=
queue
.
Queue
()
self
.
states_left_queue
=
queue
.
Queue
()
self
.
states_right_queue
=
queue
.
Queue
()
self
.
left_thread
=
None
self
.
right_thread
=
None
self
.
left_tid
=
0
self
.
right_tid
=
0
self
.
sem_left
=
threading
.
Semaphore
(
0
)
self
.
sem_right
=
threading
.
Semaphore
(
0
)
self
.
left_first
=
False
self
.
tbo_running
=
False
self
.
tbo_in_capture
=
False
if
tbo_step_stream
==
None
:
tbo_step_stream
=
torch
.
cuda
.
Stream
()
all_reduce_stream
=
torch
.
cuda
.
Stream
()
self
.
step_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_left_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_right_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_left_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_right_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
def
init_tbo_thread
(
self
):
self
.
model_input_left_queue
.
empty
()
self
.
model_input_right_queue
.
empty
()
self
.
left_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_left_queue
,))
self
.
left_thread
.
start
()
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
.
start
()
if
get_tp_group
().
rank
==
0
:
logger
.
info
(
'tbo:two batch overlap start'
)
def
finish_thread
(
self
):
self
.
left_thread
.
join
()
self
.
left_thread
=
None
self
.
right_thread
.
join
()
self
.
right_thread
=
None
@
torch
.
inference_mode
()
def
thread_two_batch_overlap
(
self
,
queue
):
is_left_thread
=
False
tid
=
threading
.
get_ident
()
if
queue
==
self
.
model_input_left_queue
:
self
.
left_tid
=
tid
is_left_thread
=
True
init_tbo_forward_context
(
True
,
self
.
left_tid
)
else
:
self
.
right_tid
=
tid
init_tbo_forward_context
(
False
,
self
.
right_tid
)
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
queue
.
get
()
self
.
tbo_thread_synchronize
(
tid
)
if
is_left_thread
:
attn_metadata
=
self
.
attn_metadata_left
num_input_tokens
=
self
.
num_input_tokens_left
input_ids
=
self
.
input_ids_left
positions
=
self
.
positions_left
else
:
attn_metadata
=
self
.
attn_metadata_right
num_input_tokens
=
self
.
num_input_tokens_right
input_ids
=
self
.
input_ids_right
positions
=
self
.
positions_right
model_output
=
None
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with
set_forward_context
(
attn_metadata
,
self
.
model_runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
self
.
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
model_output
=
self
.
model_runner
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
intermediate_tensors
=
self
.
intermediate_tensors
,
inputs_embeds
=
self
.
inputs_embeds
,
)
if
is_left_thread
:
self
.
sem_right
.
release
()
self
.
states_left_queue
.
put
(
model_output
)
else
:
self
.
states_right_queue
.
put
(
model_output
)
def
tbo_thread_synchronize
(
self
,
tid
):
if
tid
==
self
.
left_tid
:
if
not
self
.
left_first
:
self
.
sem_right
.
release
()
self
.
left_first
=
False
self
.
sem_left
.
acquire
()
return
self
.
event_left_c2t
,
self
.
event_left_t2c
else
:
self
.
sem_left
.
release
()
self
.
sem_right
.
acquire
()
return
self
.
event_right_c2t
,
self
.
event_right_t2c
def
set_model_input
(
self
,
model_runner
,
attn_metadata_left
,
attn_metadata_right
,
num_input_tokens_left
,
num_input_tokens_right
,
input_ids_left
,
input_ids_right
,
positions_left
,
positions_right
,
num_tokens_across_dp
,
intermediate_tensors
,
inputs_embeds
):
self
.
model_runner
=
model_runner
self
.
attn_metadata_left
=
attn_metadata_left
self
.
attn_metadata_right
=
attn_metadata_right
self
.
num_input_tokens_left
=
num_input_tokens_left
self
.
num_input_tokens_right
=
num_input_tokens_right
self
.
input_ids_left
=
input_ids_left
self
.
input_ids_right
=
input_ids_right
self
.
positions_left
=
positions_left
self
.
positions_right
=
positions_right
self
.
num_tokens_across_dp
=
num_tokens_across_dp
self
.
intermediate_tensors
=
intermediate_tensors
self
.
inputs_embeds
=
inputs_embeds
self
.
model_input_left_queue
.
put
(
None
)
self
.
model_input_right_queue
.
put
(
None
)
def
get_model_output
(
self
):
states_left
=
self
.
states_left_queue
.
get
()
states_right
=
self
.
states_right_queue
.
get
()
return
states_left
,
states_right
tbo_obj_v1
=
None
def
is_enable_tbo_v1
():
global
tbo_obj_v1
return
tbo_obj_v1
!=
None
def
init_two_batch_overlap
():
global
tbo_obj_v1
if
tbo_obj_v1
==
None
:
tbo_obj_v1
=
TwoBatchOverlap
()
tbo_obj_v1
.
init_tbo_thread
()
def
tbo_maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
):
from
vllm.attention.layer
import
maybe_save_kv_layer_to_connector
if
envs
.
VLLM_ENABLE_TBO
and
tbo_obj_v1
!=
None
and
tbo_obj_v1
.
tbo_running
:
tid
=
threading
.
get_ident
()
if
tid
==
tbo_obj_v1
.
left_tid
:
return
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
def
tbo_all_reduce_v1
(
obj
):
if
envs
.
VLLM_ENABLE_TBO
and
tbo_obj_v1
!=
None
and
tbo_obj_v1
.
tbo_running
:
tid
=
threading
.
get_ident
()
if
tid
==
tbo_obj_v1
.
left_tid
:
event_c2t
,
event_t2c
=
tbo_obj_v1
.
event_left_c2t
,
tbo_obj_v1
.
event_left_t2c
else
:
event_c2t
,
event_t2c
=
tbo_obj_v1
.
event_right_c2t
,
tbo_obj_v1
.
event_right_t2c
event_c2t
.
record
()
with
torch
.
cuda
.
stream
(
all_reduce_stream
):
all_reduce_stream
.
wait_event
(
event_c2t
)
output
=
tensor_model_parallel_all_reduce
(
obj
)
event_t2c
.
record
()
tbo_obj_v1
.
tbo_thread_synchronize
(
tid
)
tbo_step_stream
.
wait_event
(
event_t2c
)
return
output
return
tensor_model_parallel_all_reduce
(
obj
)
def
merge_model_output
(
states_left
,
states_right
):
if
isinstance
(
states_left
,
IntermediateTensors
):
output_map
=
{}
for
key
in
states_left
.
tensors
:
output_map
[
key
]
=
torch
.
concat
([
states_left
.
tensors
[
key
],
states_right
.
tensors
[
key
]],
dim
=
0
)
output
=
IntermediateTensors
(
output_map
)
else
:
output
=
torch
.
concat
([
states_left
,
states_right
],
dim
=
0
)
return
output
def
tbo_model_executable_v1
(
model_runner
,
attn_metadata_left
,
attn_metadata_right
,
num_input_tokens_left
,
num_input_tokens_right
,
num_tokens_across_dp
,
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
):
init_two_batch_overlap
()
tbo_obj_v1
.
tbo_running
=
True
tbo_obj_v1
.
left_first
=
True
tbo_obj_v1
.
step_event
.
record
()
current_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
tbo_step_stream
.
wait_event
(
tbo_obj_v1
.
step_event
)
tokens_split
=
[
num_input_tokens_left
,
num_input_tokens_right
]
input_ids_left
,
input_ids_right
=
torch
.
split
(
input_ids
,
tokens_split
,
dim
=
0
)
positions_left
,
positions_right
=
torch
.
split
(
positions
,
tokens_split
,
dim
=
0
)
tbo_obj_v1
.
set_model_input
(
model_runner
,
attn_metadata_left
,
attn_metadata_right
,
num_input_tokens_left
,
num_input_tokens_right
,
input_ids_left
,
input_ids_right
,
positions_left
,
positions_right
,
num_tokens_across_dp
,
intermediate_tensors
,
inputs_embeds
)
model_output_left
,
model_output_right
=
tbo_obj_v1
.
get_model_output
()
hidden_or_intermediate_states
=
merge_model_output
(
model_output_left
,
model_output_right
)
tbo_obj_v1
.
tbo_running
=
False
tbo_obj_v1
.
step_event
.
record
()
tbo_obj_v1
.
finish_thread
()
current_stream
.
wait_event
(
tbo_obj_v1
.
step_event
)
return
hidden_or_intermediate_states
\ No newline at end of file
vllm/v1/worker/gpu_model_runner.py
View file @
88443051
...
@@ -109,7 +109,6 @@ from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds,
...
@@ -109,7 +109,6 @@ from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds,
ubatch_split
)
ubatch_split
)
from
vllm.v1.worker.ubatch_utils
import
UBatchSlice
,
UBatchSlices
from
vllm.v1.worker.ubatch_utils
import
UBatchSlice
,
UBatchSlices
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
from
.utils
import
(
AttentionGroup
,
MultiModalBudget
,
from
.utils
import
(
AttentionGroup
,
MultiModalBudget
,
add_kv_sharing_layers_to_kv_cache_groups
,
bind_kv_cache
,
add_kv_sharing_layers_to_kv_cache_groups
,
bind_kv_cache
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment