Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
12291212
Commit
12291212
authored
Oct 06, 2025
by
maxiao1
Committed by
lizhigong
Oct 10, 2025
Browse files
pd分离_tbo
parent
3daae57c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
378 additions
and
302 deletions
+378
-302
vllm/attention/layer.py
vllm/attention/layer.py
+8
-8
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+5
-2
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+33
-31
vllm/two_batch_overlap/v1/model_input_split_v1.py
vllm/two_batch_overlap/v1/model_input_split_v1.py
+201
-175
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
+130
-84
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-2
No files found.
vllm/attention/layer.py
View file @
12291212
...
@@ -414,9 +414,9 @@ def unified_attention(
...
@@ -414,9 +414,9 @@ def unified_attention(
output
=
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
output
=
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
attn_metadata
)
attn_metadata
)
if
envs
.
VLLM_ENABLE_TBO
:
#
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
#
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else
:
#
else:
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
return
output
return
output
...
@@ -462,9 +462,9 @@ def unified_attention_with_output(
...
@@ -462,9 +462,9 @@ def unified_attention_with_output(
attn_metadata
,
attn_metadata
,
output
=
output
,
output
=
output
,
output_scale
=
output_scale
)
output_scale
=
output_scale
)
if
envs
.
VLLM_ENABLE_TBO
:
#
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
#
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else
:
#
else:
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
...
...
vllm/model_executor/layers/activation.py
View file @
12291212
...
@@ -75,6 +75,9 @@ class SiluAndMul(CustomOp):
...
@@ -75,6 +75,9 @@ class SiluAndMul(CustomOp):
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
"""PyTorch-native implementation equivalent to forward()."""
if
not
torch
.
compiler
.
is_compiling
():
# 非 capture 阶段
return
self
.
forward_cuda
(
x
)
# 强制走 fused kernel
else
:
d
=
x
.
shape
[
-
1
]
//
2
d
=
x
.
shape
[
-
1
]
//
2
return
F
.
silu
(
x
[...,
:
d
])
*
x
[...,
d
:]
return
F
.
silu
(
x
[...,
:
d
])
*
x
[...,
d
:]
...
...
vllm/model_executor/layers/layernorm.py
View file @
12291212
...
@@ -165,7 +165,11 @@ class RMSNorm(CustomOp):
...
@@ -165,7 +165,11 @@ class RMSNorm(CustomOp):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""PyTorch-native implementation equivalent to forward()."""
if
not
torch
.
compiler
.
is_compiling
():
# 非 capture 阶段
return
self
.
forward_cuda
(
x
,
residual
)
# 强制走 fused kernel
else
:
# 否则fallback到原始实现
orig_dtype
=
x
.
dtype
orig_dtype
=
x
.
dtype
x
=
x
.
to
(
torch
.
float32
)
x
=
x
.
to
(
torch
.
float32
)
if
residual
is
not
None
:
if
residual
is
not
None
:
...
@@ -184,11 +188,9 @@ class RMSNorm(CustomOp):
...
@@ -184,11 +188,9 @@ class RMSNorm(CustomOp):
raise
ValueError
(
raise
ValueError
(
"Expected hidden_size to be at least "
"Expected hidden_size to be at least "
f
"
{
self
.
variance_size_override
}
, but found:
{
hidden_size
}
"
)
f
"
{
self
.
variance_size_override
}
, but found:
{
hidden_size
}
"
)
x_var
=
x
[:,
:,
:
self
.
variance_size_override
]
x_var
=
x
[:,
:,
:
self
.
variance_size_override
]
variance
=
x_var
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
variance
=
x_var
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
x
=
x
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
x
=
x
.
to
(
orig_dtype
)
x
=
x
.
to
(
orig_dtype
)
if
self
.
has_weight
:
if
self
.
has_weight
:
...
...
vllm/two_batch_overlap/v1/model_input_split_v1.py
View file @
12291212
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -25,99 +24,79 @@ class TBOModelInputSplit():
...
@@ -25,99 +24,79 @@ class TBOModelInputSplit():
self
.
req_num_right
=
0
self
.
req_num_right
=
0
self
.
scheduler_output_left
=
None
self
.
scheduler_output_left
=
None
self
.
scheduler_output_right
=
None
self
.
scheduler_output_right
=
None
self
.
query_start_loc_right
=
Non
e
self
.
split_in_req
=
Fals
e
input_split
=
TBOModelInputSplit
()
input_split
=
TBOModelInputSplit
()
def
split_scheduler_output
(
runner
,
scheduler_output
:
SchedulerOutput
):
def
split_scheduler_output
(
runner
,
scheduler_output
:
SchedulerOutput
):
"""Split a step's scheduled tokens evenly into left/right halves.
If a request crosses the split boundary, mark split_in_req=True and
assign left/right token counts accordingly.
"""
split_tokens
=
scheduler_output
.
total_num_scheduled_tokens
//
2
split_tokens
=
scheduler_output
.
total_num_scheduled_tokens
//
2
req_ids
=
runner
.
input_batch
.
req_ids
split_counter
=
0
tokens_counter
=
0
num_scheduled_tokens_left
:
dict
[
int
,
int
]
=
{}
min_idx
=
-
1
num_scheduled_tokens_right
:
dict
[
int
,
int
]
=
{}
min_counter
=
0
input_split
.
req_ids_left
.
clear
()
for
i
,
id
in
enumerate
(
req_ids
):
input_split
.
req_ids_right
.
clear
()
tokens_counter
+=
scheduler_output
.
num_scheduled_tokens
[
id
]
total_num_scheduled_tokens_left
=
split_tokens
diff
=
abs
(
tokens_counter
-
split_tokens
)
total_num_scheduled_tokens_right
=
scheduler_output
.
total_num_scheduled_tokens
-
split_tokens
if
min_idx
==
-
1
or
diff
<
min_counter
:
min_idx
=
i
req_splited
=
False
min_counter
=
diff
input_split
.
split_in_req
=
False
if
tokens_counter
>
split_tokens
or
diff
==
0
:
break
input_split
.
req_num_left
=
min_idx
+
1
if
input_split
.
req_num_left
==
len
(
req_ids
):
input_split
.
req_num_left
=
input_split
.
req_num_left
-
1
input_split
.
req_ids_left
=
req_ids
[:
input_split
.
req_num_left
]
input_split
.
req_ids_right
=
req_ids
[
input_split
.
req_num_left
:]
input_split
.
req_num_right
=
len
(
req_ids
)
-
input_split
.
req_num_left
new_req_data_left
=
[]
new_req_data_right
=
[]
cached_reqs_left
=
[]
cached_reqs_right
=
[]
num_scheduled_tokens_left
=
{}
num_scheduled_tokens_right
=
{}
total_num_scheduled_tokens_left
=
0
total_num_scheduled_tokens_right
=
0
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
if
new_req
.
req_id
in
input_split
.
req_ids_left
:
new_req_data_left
.
append
(
new_req
)
else
:
new_req_data_right
.
append
(
new_req
)
cached_reqs_left
=
CachedRequestData
.
make_empty
()
cached_reqs_right
=
CachedRequestData
.
make_empty
()
for
req_idx
,
req_id
in
enumerate
(
scheduler_output
.
scheduled_cached_reqs
.
req_ids
):
if
req_id
in
input_split
.
req_ids_left
:
cached_reqs_left
.
req_ids
.
append
(
req_id
)
cached_reqs_left
.
resumed_from_preemption
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
resumed_from_preemption
[
req_idx
])
if
len
(
scheduler_output
.
scheduled_cached_reqs
.
new_token_ids
)
>
0
:
cached_reqs_left
.
new_token_ids
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
new_token_ids
[
req_idx
])
cached_reqs_left
.
new_block_ids
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
new_block_ids
[
req_idx
])
cached_reqs_left
.
num_computed_tokens
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
num_computed_tokens
[
req_idx
])
else
:
cached_reqs_right
.
req_ids
.
append
(
req_id
)
cached_reqs_right
.
resumed_from_preemption
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
resumed_from_preemption
[
req_idx
])
if
len
(
scheduler_output
.
scheduled_cached_reqs
.
new_token_ids
)
>
0
:
cached_reqs_right
.
new_token_ids
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
new_token_ids
[
req_idx
])
cached_reqs_right
.
new_block_ids
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
new_block_ids
[
req_idx
])
cached_reqs_right
.
num_computed_tokens
.
append
(
scheduler_output
.
scheduled_cached_reqs
.
num_computed_tokens
[
req_idx
])
for
key
,
value
in
scheduler_output
.
num_scheduled_tokens
.
items
():
for
key
,
value
in
scheduler_output
.
num_scheduled_tokens
.
items
():
if
key
in
input_split
.
req_ids_left
:
split_counter
+=
value
if
split_counter
==
split_tokens
:
req_splited
=
True
num_scheduled_tokens_left
[
key
]
=
value
num_scheduled_tokens_left
[
key
]
=
value
total_num_scheduled_tokens_left
+=
value
input_split
.
req_ids_left
.
append
(
key
)
else
:
elif
split_counter
>
split_tokens
:
if
req_splited
:
# boundary already hit earlier; entire req goes to right
num_scheduled_tokens_right
[
key
]
=
value
num_scheduled_tokens_right
[
key
]
=
value
total_num_scheduled_tokens_right
+=
value
input_split
.
req_ids_right
.
append
(
key
)
else
:
# The boundary falls inside this request -> split within req
req_splited
=
True
input_split
.
split_in_req
=
True
right_tokens
=
split_counter
-
split_tokens
left_tokens
=
value
-
right_tokens
# right part
num_scheduled_tokens_right
[
key
]
=
right_tokens
input_split
.
req_ids_right
.
append
(
key
)
# left part
num_scheduled_tokens_left
[
key
]
=
left_tokens
input_split
.
req_ids_left
.
append
(
key
)
else
:
# before boundary, entire req goes to left
num_scheduled_tokens_left
[
key
]
=
value
input_split
.
req_ids_left
.
append
(
key
)
input_split
.
req_num_left
=
len
(
input_split
.
req_ids_left
)
input_split
.
req_num_right
=
len
(
input_split
.
req_ids_right
)
input_split
.
scheduler_output_left
=
SchedulerOutput
(
input_split
.
scheduler_output_left
=
SchedulerOutput
(
scheduled_new_reqs
=
ne
w_req_data_left
,
scheduled_new_reqs
=
No
ne
,
scheduled_cached_reqs
=
cached_reqs_left
,
scheduled_cached_reqs
=
None
,
num_scheduled_tokens
=
num_scheduled_tokens_left
,
num_scheduled_tokens
=
num_scheduled_tokens_left
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens_left
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens_left
,
scheduled_spec_decode_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
,
scheduled_spec_decode_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
,
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
,
##
unsupport yet
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
,
#
unsupport
ed
yet
num_common_prefix_blocks
=
scheduler_output
.
num_common_prefix_blocks
,
num_common_prefix_blocks
=
scheduler_output
.
num_common_prefix_blocks
,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids
=
scheduler_output
.
finished_req_ids
,
finished_req_ids
=
scheduler_output
.
finished_req_ids
,
free_encoder_input_ids
=
scheduler_output
.
free_encoder_input_ids
,
free_encoder_input_ids
=
scheduler_output
.
free_encoder_input_ids
,
structured_output_request_ids
=
scheduler_output
.
structured_output_request_ids
,
structured_output_request_ids
=
scheduler_output
.
structured_output_request_ids
,
grammar_bitmask
=
scheduler_output
.
grammar_bitmask
,
grammar_bitmask
=
scheduler_output
.
grammar_bitmask
,
)
)
input_split
.
scheduler_output_right
=
SchedulerOutput
(
input_split
.
scheduler_output_right
=
SchedulerOutput
(
scheduled_new_reqs
=
ne
w_req_data_right
,
scheduled_new_reqs
=
No
ne
,
scheduled_cached_reqs
=
cached_reqs_right
,
scheduled_cached_reqs
=
None
,
num_scheduled_tokens
=
num_scheduled_tokens_right
,
num_scheduled_tokens
=
num_scheduled_tokens_right
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens_right
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens_right
,
scheduled_spec_decode_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
,
scheduled_spec_decode_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
,
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
,
##
unsupport yet
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
,
#
unsupport
ed
yet
num_common_prefix_blocks
=
scheduler_output
.
num_common_prefix_blocks
,
num_common_prefix_blocks
=
scheduler_output
.
num_common_prefix_blocks
,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids
=
scheduler_output
.
finished_req_ids
,
finished_req_ids
=
scheduler_output
.
finished_req_ids
,
free_encoder_input_ids
=
scheduler_output
.
free_encoder_input_ids
,
free_encoder_input_ids
=
scheduler_output
.
free_encoder_input_ids
,
structured_output_request_ids
=
scheduler_output
.
structured_output_request_ids
,
structured_output_request_ids
=
scheduler_output
.
structured_output_request_ids
,
...
@@ -129,102 +108,159 @@ def prepare_tbo_atten_metadata(
...
@@ -129,102 +108,159 @@ def prepare_tbo_atten_metadata(
runner
,
runner
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
req_ids
,
req_ids
,
req_offset
req_offset
:
int
,
)
->
tuple
[
dict
[
str
,
Any
],
torch
.
Tensor
,
Optional
[
SpecDecodeMetadata
]]:
)
->
dict
[
str
,
Any
]:
# (attn_metadata)
"""Prepare attention metadata for one half (left/right).
Key fixes for correctness when a request is split:
- Align seq_len_offset / query_start_offset with block_table slicing.
- For the right half, if a request was split, make seq_lens[0]
= (history + left-prefix + right-half tokens).
- Pass cloned slices to CommonAttentionMetadata to avoid aliasing.
"""
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
assert
total_num_scheduled_tokens
>
0
num_reqs
=
len
(
req_ids
)
num_reqs
=
len
(
req_ids
)
assert
num_reqs
>
0
assert
num_reqs
>
0
seq_len_offset
=
req_offset
# Tokens per req in THIS half
# Get the number of scheduled tokens for each request.
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
max
(
tokens
)
max_num_scheduled_tokens
=
max
(
tokens
)
if
req_offset
>
0
:
#right
# Request indices (relative to the WHOLE step), used by kernels
if
input_split
.
query_start_loc_right
==
None
:
req_indices
=
np
.
repeat
(
runner
.
arange_np
[:
num_reqs
],
# TODO: create when system init
num_scheduled_tokens
)
+
req_offset
input_split
.
query_start_loc_right
=
torch
.
zeros
(
runner
.
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
runner
.
device
)
cu_num_tokens
,
arange
=
runner
.
_get_cumsum_and_arange
(
# Cumulative tokens within this half
num_scheduled_tokens
)
cu_num_tokens
,
arange
=
runner
.
_get_cumsum_and_arange
(
num_scheduled_tokens
)
#
Prepare the attention metadata.
#
--- query_start_loc (within this half) ---
runner
.
query_start_loc_np
[
0
]
=
0
runner
.
query_start_loc_np
[
0
]
=
0
runner
.
query_start_loc_np
[
1
:
num_reqs
+
1
]
=
cu_num_tokens
runner
.
query_start_loc_np
[
1
:
num_reqs
+
1
]
=
cu_num_tokens
# --- seq_lens (absolute context length per-req row) ---
# Default (no split across req boundary)
# Maps rows [req_offset ... req_offset+num_reqs-1]
default_seq_lens
=
(
runner
.
input_batch
.
num_computed_tokens_cpu
[
req_offset
:
req_offset
+
num_reqs
]
+
num_scheduled_tokens
)
input_split
.
query_start_loc_right
[
0
:
num_reqs
+
1
].
copy_
(
# Offsets for copying into the *global* GPU buffers
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
],
non_blocking
=
True
)
# Left-half writes at the natural position; right-half depends on split.
# Note: pad query_start_loc to be non-decreasing, as kernels
if
req_offset
==
0
:
# like FlashAttention requires that
# LEFT
input_split
.
query_start_loc_right
[
num_reqs
+
1
:].
fill_
(
seq_len_offset
=
0
runner
.
query_start_loc_cpu
[
num_reqs
].
item
())
query_start_offset
=
0
query_start_loc
=
input_split
.
query_start_loc_right
[:
num_reqs
+
1
]
seq_lens_cpu_local
=
torch
.
as_tensor
(
default_seq_lens
,
device
=
runner
.
seq_lens_cpu
.
device
)
else
:
else
:
query_start_loc
=
runner
.
query_start_loc
[:
num_reqs
+
1
]
# RIGHT
if
input_split
.
split_in_req
:
# The block_table for RIGHT starts from (req_offset-1).
# Align both offsets to that, and re-build the seq_lens for row-0.
seq_len_offset
=
req_offset
-
1
query_start_offset
=
req_offset
-
1
# row-0 is the split request (global row index = req_offset-1):
base_hist
=
runner
.
input_batch
.
num_computed_tokens_cpu
[
req_offset
-
1
].
item
()
left_prefix
=
input_split
.
scheduler_output_left
.
num_scheduled_tokens
[
req_ids
[
0
]]
right_tokens0
=
scheduler_output
.
num_scheduled_tokens
[
req_ids
[
0
]]
first_row
=
base_hist
+
left_prefix
+
right_tokens0
if
num_reqs
>
1
:
# rows 1.. map to global rows [req_offset .. req_offset+num_reqs-2]
tail_base
=
runner
.
input_batch
.
num_computed_tokens_cpu
[
req_offset
:
req_offset
+
num_reqs
-
1
]
tail_tokens
=
num_scheduled_tokens
[
1
:]
tail
=
tail_base
+
tail_tokens
seq_lens_cpu_local
=
torch
.
empty
(
num_reqs
,
dtype
=
runner
.
seq_lens_cpu
.
dtype
,
device
=
runner
.
seq_lens_cpu
.
device
)
seq_lens_cpu_local
[
0
]
=
first_row
seq_lens_cpu_local
[
1
:]
=
torch
.
as_tensor
(
tail
,
device
=
runner
.
seq_lens_cpu
.
device
)
else
:
seq_lens_cpu_local
=
torch
.
tensor
([
first_row
],
dtype
=
runner
.
seq_lens_cpu
.
dtype
,
device
=
runner
.
seq_lens_cpu
.
device
)
else
:
# RIGHT without split-in-req: natural positions
seq_len_offset
=
req_offset
query_start_offset
=
req_offset
seq_lens_cpu_local
=
torch
.
as_tensor
(
default_seq_lens
,
device
=
runner
.
seq_lens_cpu
.
device
)
# Copy query_start_loc into global GPU buffer window
runner
.
query_start_loc
[
query_start_offset
:
query_start_offset
+
num_reqs
+
1
].
copy_
(
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
],
non_blocking
=
True
)
# Pad tail (FlashAttn requires non-decreasing)
if
req_offset
>
0
:
runner
.
query_start_loc
[
query_start_offset
+
num_reqs
+
1
:].
fill_
(
runner
.
query_start_loc_cpu
[
num_reqs
].
item
()
)
seq_lens
=
runner
.
seq_lens
[
seq_len_offset
:
seq_len_offset
+
num_reqs
]
# Copy seq_lens into the aligned window; zero out the remainder on RIGHT
runner
.
seq_lens
[
seq_len_offset
:
seq_len_offset
+
num_reqs
].
copy_
(
seq_lens_cpu_local
,
non_blocking
=
True
)
if
req_offset
>
0
:
runner
.
seq_lens
[
seq_len_offset
+
num_reqs
:].
fill_
(
0
)
# Build common metadata (pass CLONES to avoid aliasing between threads)
query_start_loc
=
runner
.
query_start_loc
[
query_start_offset
:
query_start_offset
+
num_reqs
+
1
].
clone
()
seq_lens
=
runner
.
seq_lens
[
seq_len_offset
:
seq_len_offset
+
num_reqs
].
clone
()
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
)
max_query_len
=
max_num_scheduled_tokens
,
)
# Prepare attention metadata for each KV cache group
attn_metadata
:
dict
[
str
,
Any
]
=
{}
attn_metadata
:
dict
[
str
,
Any
]
=
{}
# Prepare the attention metadata for each KV cache group and make layers
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
runner
.
kv_cache_config
.
kv_cache_groups
):
# in the same group share the same metadata.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
runner
.
kv_cache_config
.
kv_cache_groups
):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
common_prefix_len
=
0
metadata_builder
=
runner
.
attn_metadata_builders
[
kv_cache_group_id
]
metadata_builder
=
runner
.
attn_metadata_builders
[
kv_cache_group_id
]
if
runner
.
cascade_attn_enabled
:
if
runner
.
cascade_attn_enabled
:
common_prefix_len
=
runner
.
_compute_cascade_attn_prefix_len
(
common_prefix_len
=
runner
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
num_scheduled_tokens
,
scheduler_output
.
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
kv_cache_group_spec
.
kv_cache_spec
,
metadata_builder
,
metadata_builder
,
)
)
# Slice block_table / slot_mapping for RIGHT half
if
req_offset
>
0
:
if
req_offset
>
0
:
origin_block_table
=
metadata_builder
.
block_table
.
block_table
origin_block_table
=
metadata_builder
.
block_table
.
block_table
if
input_split
.
split_in_req
:
metadata_builder
.
block_table
.
block_table
=
origin_block_table
[
req_offset
-
1
:,
:]
else
:
metadata_builder
.
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
metadata_builder
.
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
origin_slot_mapping
=
metadata_builder
.
block_table
.
slot_mapping
origin_slot_mapping
=
metadata_builder
.
block_table
.
slot_mapping
metadata_builder
.
block_table
.
slot_mapping
=
\
origin_slot_mapping_cpu
=
metadata_builder
.
block_table
.
slot_mapping_cpu
origin_slot_mapping
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
left_tokens
=
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
origin_slot_map_cpu
=
metadata_builder
.
block_table
.
slot_mapping_cpu
metadata_builder
.
block_table
.
slot_mapping
=
origin_slot_mapping
[
left_tokens
:]
metadata_builder
.
block_table
.
slot_mapping_cpu
=
\
metadata_builder
.
block_table
.
slot_mapping_cpu
=
origin_slot_mapping_cpu
[
left_tokens
:]
origin_slot_map_cpu
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
# MLA-specific counters (safe to ignore for Qwen/FA paths)
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
_num_decodes_record
=
metadata_builder
.
_num_decodes
_num_decodes_record
=
metadata_builder
.
_num_decodes
_num_prefills_record
=
metadata_builder
.
_num_prefills
_num_prefills_record
=
metadata_builder
.
_num_prefills
_num_decode_tokens_record
=
metadata_builder
.
_num_decode_tokens
_num_decode_tokens_record
=
metadata_builder
.
_num_decode_tokens
_num_prefill_tokens_record
=
metadata_builder
.
_num_prefill_tokens
_num_prefill_tokens_record
=
metadata_builder
.
_num_prefill_tokens
metadata_builder
.
_num_decodes
=
0
metadata_builder
.
_num_decodes
=
0
metadata_builder
.
_num_prefills
=
num_reqs
metadata_builder
.
_num_prefills
=
num_reqs
metadata_builder
.
_num_decode_tokens
=
0
metadata_builder
.
_num_decode_tokens
=
0
metadata_builder
.
_num_prefill_tokens
=
total_num_scheduled_tokens
metadata_builder
.
_num_prefill_tokens
=
total_num_scheduled_tokens
attn_metadata_i
=
(
metadata_builder
.
build
(
attn_metadata_i
=
metadata_builder
.
build
(
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
))
# maybe FlashAttentionMetadata
common_attn_metadata
=
common_attn_metadata
,
)
# Restore tables
if
req_offset
>
0
:
if
req_offset
>
0
:
metadata_builder
.
block_table
.
block_table
=
origin_block_table
metadata_builder
.
block_table
.
block_table
=
origin_block_table
metadata_builder
.
block_table
.
slot_mapping
=
origin_slot_mapping
metadata_builder
.
block_table
.
slot_mapping
=
origin_slot_mapping
metadata_builder
.
block_table
.
slot_mapping_cpu
=
origin_slot_map_cpu
metadata_builder
.
block_table
.
slot_mapping_cpu
=
origin_slot_map
ping
_cpu
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
metadata_builder
.
_num_decodes
=
_num_decodes_record
metadata_builder
.
_num_decodes
=
_num_decodes_record
metadata_builder
.
_num_prefills
=
_num_prefills_record
metadata_builder
.
_num_prefills
=
_num_prefills_record
metadata_builder
.
_num_decode_tokens
=
_num_decode_tokens_record
metadata_builder
.
_num_decode_tokens
=
_num_decode_tokens_record
...
@@ -235,31 +271,27 @@ def prepare_tbo_atten_metadata(
...
@@ -235,31 +271,27 @@ def prepare_tbo_atten_metadata(
return
attn_metadata
return
attn_metadata
def
pad_num_input_tokens
(
self
,
scheduler_output
):
def
pad_num_input_tokens
(
self
,
scheduler_output
):
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
(
self
.
use_cuda_graph
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# CUDA graphs (piecewise). Add padding to batch size.
# Use piecewise CUDA graphs.
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_scheduled_tokens
)
# Add padding to the batch size.
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_scheduled_tokens
)
else
:
else
:
# Eager mode.
# Eager mode: pad to TP multiple for SP+collective fusion
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
vllm_config
.
compilation_config
.
pass_config
.
\
if
self
.
vllm_config
.
compilation_config
.
pass_config
.
enable_sequence_parallelism
and
tp_size
>
1
:
enable_sequence_parallelism
and
tp_size
>
1
:
from
vllm.utils
import
round_up
from
vllm.utils
import
round_up
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
tp_size
)
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
tp_size
)
else
:
else
:
num_input_tokens
=
num_scheduled_tokens
num_input_tokens
=
num_scheduled_tokens
#
P
adding
for DP
#
DP p
adding
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_input_tokens
)
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_input_tokens
)
num_input_tokens
+=
num_pad
num_input_tokens
+=
num_pad
return
num_input_tokens
,
num_tokens_across_dp
return
num_input_tokens
,
num_tokens_across_dp
def
tbo_split_and_execute_model
(
def
tbo_split_and_execute_model
(
runner
,
runner
,
attn_metadata
,
attn_metadata
,
...
@@ -269,23 +301,39 @@ def tbo_split_and_execute_model(
...
@@ -269,23 +301,39 @@ def tbo_split_and_execute_model(
positions
,
positions
,
inputs_embeds
,
inputs_embeds
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
skip_cuda_graphs
:
bool
=
True
,
skip_cuda_graphs
:
bool
,
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
use_tbo
=
False
# If below TBO threshold, run the normal single-batch path (supports decode/prefill as-is).
if
isinstance
(
runner
.
attn_metadata_builders
[
0
],
MLACommonMetadataBuilder
)
and
\
# Two-batch overlap path
runner
.
attn_metadata_builders
[
0
].
_num_decodes
>
0
:
#is mla decode
use_tbo
=
False
else
:
if
len
(
scheduler_output
.
num_scheduled_tokens
)
>
1
and
num_input_tokens
>
envs
.
VLLM_TBO_MIN_TOKENS
:
split_scheduler_output
(
runner
,
scheduler_output
)
split_scheduler_output
(
runner
,
scheduler_output
)
use_tbo
=
True
if
use_tbo
:
num_input_tokens_left
=
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
num_input_tokens_left
=
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
num_input_tokens_right
=
num_input_tokens
-
num_input_tokens_left
num_input_tokens_right
=
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
attn_metadata_left
=
prepare_tbo_atten_metadata
(
runner
,
input_split
.
scheduler_output_left
,
input_split
.
req_ids_left
,
0
)
attn_metadata_right
=
prepare_tbo_atten_metadata
(
runner
,
input_split
.
scheduler_output_right
,
input_split
.
req_ids_right
,
input_split
.
req_num_left
)
attn_metadata_left
=
prepare_tbo_atten_metadata
(
runner
,
input_split
.
scheduler_output_left
,
input_split
.
req_ids_left
,
0
)
# === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector ===
attn_metadata_right
=
prepare_tbo_atten_metadata
(
runner
,
input_split
.
scheduler_output_right
,
input_split
.
req_ids_right
,
input_split
.
req_num_left
)
# 真实 token
real_L
=
int
(
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
)
real_R
=
int
(
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
)
# 按左右半批切成两份
def
_split_it
(
it
,
l
,
r
):
if
it
is
None
:
return
None
,
None
lm
,
rm
=
{},
{}
for
k
,
v
in
it
.
tensors
.
items
():
vl
,
vr
=
torch
.
split
(
v
[:
l
+
r
],
[
l
,
r
],
dim
=
0
)
lm
[
k
],
rm
[
k
]
=
vl
,
vr
return
IntermediateTensors
(
lm
),
IntermediateTensors
(
rm
)
intermediate_tensors_left
,
intermediate_tensors_right
=
_split_it
(
intermediate_tensors
,
real_L
,
real_R
)
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
attn_metadata
,
runner
.
vllm_config
,
runner
.
vllm_config
,
...
@@ -303,33 +351,11 @@ def tbo_split_and_execute_model(
...
@@ -303,33 +351,11 @@ def tbo_split_and_execute_model(
num_tokens_across_dp
,
num_tokens_across_dp
,
input_ids
,
input_ids
,
positions
,
positions
,
intermediate_tensors
,
(
intermediate_tensors
_left
,
intermediate_tensors_right
)
,
inputs_embeds
)
inputs_embeds
)
runner
.
maybe_wait_for_kv_save
()
runner
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
finished_sending
,
finished_recving
=
(
runner
.
get_finished_kv_transfers
(
scheduler_output
))
runner
.
get_finished_kv_transfers
(
scheduler_output
))
#finished_sending, finished_recving = None, None
else
:
# Run the decoder.
# Use persistent buffers for CUDA graphs.
envs
.
VLLM_ENABLE_TBO
=
False
with
set_forward_context
(
attn_metadata
,
runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
skip_cuda_graphs
=
skip_cuda_graphs
):
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
runner
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
runner
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
runner
.
get_finished_kv_transfers
(
scheduler_output
))
envs
.
VLLM_ENABLE_TBO
=
True
return
model_output
,
finished_sending
,
finished_recving
return
model_output
,
finished_sending
,
finished_recving
\ No newline at end of file
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
View file @
12291212
...
@@ -17,10 +17,12 @@ logger = init_logger(__name__)
...
@@ -17,10 +17,12 @@ logger = init_logger(__name__)
tbo_step_stream
=
None
tbo_step_stream
=
None
all_reduce_stream
=
None
all_reduce_stream
=
None
class
TwoBatchOverlap
():
PERSIST_THREADS
=
os
.
getenv
(
'VLLM_TBO_PERSIST_THREADS'
,
'1'
)
not
in
(
'0'
,
'false'
,
'False'
,
'no'
,
'NO'
,
''
)
STOP
=
object
()
class
TwoBatchOverlap
:
def
__init__
(
self
):
def
__init__
(
self
):
global
tbo_step_stream
global
tbo_step_stream
,
all_reduce_stream
global
all_reduce_stream
self
.
model_input_left_queue
=
queue
.
Queue
()
self
.
model_input_left_queue
=
queue
.
Queue
()
self
.
model_input_right_queue
=
queue
.
Queue
()
self
.
model_input_right_queue
=
queue
.
Queue
()
self
.
states_left_queue
=
queue
.
Queue
()
self
.
states_left_queue
=
queue
.
Queue
()
...
@@ -29,12 +31,14 @@ class TwoBatchOverlap():
...
@@ -29,12 +31,14 @@ class TwoBatchOverlap():
self
.
right_thread
=
None
self
.
right_thread
=
None
self
.
left_tid
=
0
self
.
left_tid
=
0
self
.
right_tid
=
0
self
.
right_tid
=
0
self
.
_stop_evt
=
threading
.
Event
()
self
.
_threads_started
=
False
self
.
sem_left
=
threading
.
Semaphore
(
0
)
self
.
sem_left
=
threading
.
Semaphore
(
0
)
self
.
sem_right
=
threading
.
Semaphore
(
0
)
self
.
sem_right
=
threading
.
Semaphore
(
0
)
self
.
left_first
=
False
self
.
left_first
=
False
self
.
tbo_running
=
False
self
.
tbo_running
=
False
self
.
tbo_in_capture
=
False
self
.
tbo_in_capture
=
False
if
tbo_step_stream
==
None
:
if
tbo_step_stream
is
None
:
tbo_step_stream
=
torch
.
cuda
.
Stream
()
tbo_step_stream
=
torch
.
cuda
.
Stream
()
all_reduce_stream
=
torch
.
cuda
.
Stream
()
all_reduce_stream
=
torch
.
cuda
.
Stream
()
self
.
step_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
step_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
...
@@ -44,35 +48,52 @@ class TwoBatchOverlap():
...
@@ -44,35 +48,52 @@ class TwoBatchOverlap():
self
.
event_right_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_right_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
def
init_tbo_thread
(
self
):
def
init_tbo_thread
(
self
):
self
.
model_input_left_queue
.
empty
()
if
self
.
_threads_started
and
PERSIST_THREADS
:
self
.
model_input_right_queue
.
empty
()
return
self
.
left_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_left_queue
,))
if
self
.
left_thread
is
None
or
not
self
.
left_thread
.
is_alive
():
self
.
left_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_left_queue
,),
daemon
=
True
)
self
.
left_thread
.
start
()
self
.
left_thread
.
start
()
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
if
self
.
right_thread
is
None
or
not
self
.
right_thread
.
is_alive
():
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,),
daemon
=
True
)
self
.
right_thread
.
start
()
self
.
right_thread
.
start
()
if
get_tp_group
().
rank
==
0
:
self
.
_threads_started
=
True
logger
.
info
(
'tbo:two batch overlap start'
)
def
finish_thread
(
self
):
def
shutdown
(
self
,
timeout
=
5.0
):
self
.
left_thread
.
join
()
self
.
_stop_evt
.
set
()
try
:
self
.
model_input_left_queue
.
put
(
STOP
)
self
.
model_input_right_queue
.
put
(
STOP
)
except
Exception
:
pass
if
self
.
left_thread
is
not
None
:
self
.
left_thread
.
join
(
timeout
=
timeout
)
self
.
left_thread
=
None
self
.
left_thread
=
None
self
.
right_thread
.
join
()
if
self
.
right_thread
is
not
None
:
self
.
right_thread
.
join
(
timeout
=
timeout
)
self
.
right_thread
=
None
self
.
right_thread
=
None
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
thread_two_batch_overlap
(
self
,
q
ueue
):
def
thread_two_batch_overlap
(
self
,
q
):
is_left_thread
=
False
is_left_thread
=
False
tid
=
threading
.
get_ident
()
tid
=
threading
.
get_ident
()
if
q
ueue
==
self
.
model_input_left_queue
:
if
q
is
self
.
model_input_left_queue
:
self
.
left_tid
=
tid
self
.
left_tid
=
tid
is_left_thread
=
True
is_left_thread
=
True
init_tbo_forward_context
(
True
,
self
.
left_tid
)
init_tbo_forward_context
(
True
,
self
.
left_tid
)
else
:
else
:
self
.
right_tid
=
tid
self
.
right_tid
=
tid
init_tbo_forward_context
(
False
,
self
.
right_tid
)
init_tbo_forward_context
(
False
,
self
.
right_tid
)
while
not
self
.
_stop_evt
.
is_set
():
item
=
q
.
get
()
if
item
is
STOP
:
break
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
queue
.
get
()
self
.
tbo_thread_synchronize
(
tid
)
self
.
tbo_thread_synchronize
(
tid
)
if
is_left_thread
:
if
is_left_thread
:
attn_metadata
=
self
.
attn_metadata_left
attn_metadata
=
self
.
attn_metadata_left
num_input_tokens
=
self
.
num_input_tokens_left
num_input_tokens
=
self
.
num_input_tokens_left
...
@@ -84,20 +105,28 @@ class TwoBatchOverlap():
...
@@ -84,20 +105,28 @@ class TwoBatchOverlap():
input_ids
=
self
.
input_ids_right
input_ids
=
self
.
input_ids_right
positions
=
self
.
positions_right
positions
=
self
.
positions_right
model_output
=
None
# Select per-thread tensors (left/right) with backward-compatible fallback
# Run the decoder.
if
is_left_thread
:
# Use persistent buffers for CUDA graphs.
intermediate_tensors
=
getattr
(
self
,
'intermediate_tensors_left'
,
None
)
else
:
intermediate_tensors
=
getattr
(
self
,
'intermediate_tensors_right'
,
None
)
if
intermediate_tensors
is
None
:
intermediate_tensors
=
getattr
(
self
,
'intermediate_tensors_left'
,
None
)
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
attn_metadata
,
self
.
model_runner
.
vllm_config
,
self
.
model_runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
self
.
num_tokens_across_dp
,
num_tokens_across_dp
=
self
.
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
skip_cuda_graphs
=
True
,
):
model_output
=
self
.
model_runner
.
model
(
model_output
=
self
.
model_runner
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
intermediate_tensors
=
self
.
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
self
.
inputs_embeds
,
inputs_embeds
=
self
.
inputs_embeds
,
)
)
if
is_left_thread
:
if
is_left_thread
:
self
.
sem_right
.
release
()
self
.
sem_right
.
release
()
self
.
states_left_queue
.
put
(
model_output
)
self
.
states_left_queue
.
put
(
model_output
)
...
@@ -128,7 +157,8 @@ class TwoBatchOverlap():
...
@@ -128,7 +157,8 @@ class TwoBatchOverlap():
positions_right
,
positions_right
,
num_tokens_across_dp
,
num_tokens_across_dp
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
):
inputs_embeds
,
):
self
.
model_runner
=
model_runner
self
.
model_runner
=
model_runner
self
.
attn_metadata_left
=
attn_metadata_left
self
.
attn_metadata_left
=
attn_metadata_left
self
.
attn_metadata_right
=
attn_metadata_right
self
.
attn_metadata_right
=
attn_metadata_right
...
@@ -139,9 +169,14 @@ class TwoBatchOverlap():
...
@@ -139,9 +169,14 @@ class TwoBatchOverlap():
self
.
positions_left
=
positions_left
self
.
positions_left
=
positions_left
self
.
positions_right
=
positions_right
self
.
positions_right
=
positions_right
self
.
num_tokens_across_dp
=
num_tokens_across_dp
self
.
num_tokens_across_dp
=
num_tokens_across_dp
self
.
intermediate_tensors
=
intermediate_tensors
self
.
inputs_embeds
=
inputs_embeds
self
.
inputs_embeds
=
inputs_embeds
if
isinstance
(
intermediate_tensors
,
tuple
):
self
.
intermediate_tensors_left
,
self
.
intermediate_tensors_right
=
intermediate_tensors
else
:
self
.
intermediate_tensors_left
=
intermediate_tensors
self
.
intermediate_tensors_right
=
None
self
.
model_input_left_queue
.
put
(
None
)
self
.
model_input_left_queue
.
put
(
None
)
self
.
model_input_right_queue
.
put
(
None
)
self
.
model_input_right_queue
.
put
(
None
)
...
@@ -150,15 +185,18 @@ class TwoBatchOverlap():
...
@@ -150,15 +185,18 @@ class TwoBatchOverlap():
states_right
=
self
.
states_right_queue
.
get
()
states_right
=
self
.
states_right_queue
.
get
()
return
states_left
,
states_right
return
states_left
,
states_right
tbo_obj_v1
=
None
tbo_obj_v1
=
None
def
is_enable_tbo_v1
():
def
is_enable_tbo_v1
():
global
tbo_obj_v1
global
tbo_obj_v1
return
tbo_obj_v1
!=
None
return
tbo_obj_v1
is
not
None
def
init_two_batch_overlap
():
def
init_two_batch_overlap
():
global
tbo_obj_v1
global
tbo_obj_v1
if
tbo_obj_v1
==
None
:
if
tbo_obj_v1
is
None
:
tbo_obj_v1
=
TwoBatchOverlap
()
tbo_obj_v1
=
TwoBatchOverlap
()
tbo_obj_v1
.
init_tbo_thread
()
tbo_obj_v1
.
init_tbo_thread
()
...
@@ -171,7 +209,7 @@ def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
...
@@ -171,7 +209,7 @@ def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
def
tbo_all_reduce_v1
(
obj
):
def
tbo_all_reduce_v1
(
obj
):
if
envs
.
VLLM_ENABLE_TBO
and
tbo_obj_v1
!=
None
and
tbo_obj_v1
.
tbo_running
:
if
envs
.
VLLM_ENABLE_TBO
and
tbo_obj_v1
is
not
None
and
tbo_obj_v1
.
tbo_running
:
tid
=
threading
.
get_ident
()
tid
=
threading
.
get_ident
()
if
tid
==
tbo_obj_v1
.
left_tid
:
if
tid
==
tbo_obj_v1
.
left_tid
:
event_c2t
,
event_t2c
=
tbo_obj_v1
.
event_left_c2t
,
tbo_obj_v1
.
event_left_t2c
event_c2t
,
event_t2c
=
tbo_obj_v1
.
event_left_c2t
,
tbo_obj_v1
.
event_left_t2c
...
@@ -207,19 +245,19 @@ def tbo_model_executable_v1(
...
@@ -207,19 +245,19 @@ def tbo_model_executable_v1(
input_ids
,
input_ids
,
positions
,
positions
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
inputs_embeds
,
):
):
init_two_batch_overlap
()
init_two_batch_overlap
()
tbo_obj_v1
.
tbo_running
=
True
tbo_obj_v1
.
tbo_running
=
True
tbo_obj_v1
.
left_first
=
True
tbo_obj_v1
.
left_first
=
True
tbo_obj_v1
.
step_event
.
record
()
tbo_obj_v1
.
step_event
.
record
()
current_stream
=
torch
.
cuda
.
current_stream
()
current_stream
=
torch
.
cuda
.
current_stream
()
num_total_tokens
=
num_input_tokens_left
+
num_input_tokens_right
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
tbo_step_stream
.
wait_event
(
tbo_obj_v1
.
step_event
)
tbo_step_stream
.
wait_event
(
tbo_obj_v1
.
step_event
)
tokens_split
=
[
num_input_tokens_left
,
num_input_tokens_right
]
tokens_split
=
[
num_input_tokens_left
,
num_input_tokens_right
]
input_ids_left
,
input_ids_right
=
torch
.
split
(
input_ids
,
tokens_split
,
dim
=
0
)
input_ids_left
,
input_ids_right
=
torch
.
split
(
input_ids
[:
num_total_tokens
]
,
tokens_split
,
dim
=
0
)
positions_left
,
positions_right
=
torch
.
split
(
positions
,
tokens_split
,
dim
=
0
)
positions_left
,
positions_right
=
torch
.
split
(
positions
[:
num_total_tokens
]
,
tokens_split
,
dim
=
0
)
tbo_obj_v1
.
set_model_input
(
model_runner
,
tbo_obj_v1
.
set_model_input
(
model_runner
,
attn_metadata_left
,
attn_metadata_left
,
attn_metadata_right
,
attn_metadata_right
,
...
@@ -231,13 +269,21 @@ def tbo_model_executable_v1(
...
@@ -231,13 +269,21 @@ def tbo_model_executable_v1(
positions_right
,
positions_right
,
num_tokens_across_dp
,
num_tokens_across_dp
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
,
)
model_output_left
,
model_output_right
=
tbo_obj_v1
.
get_model_output
()
model_output_left
,
model_output_right
=
tbo_obj_v1
.
get_model_output
()
hidden_or_intermediate_states
=
merge_model_output
(
model_output_left
,
model_output_right
)
hidden_or_intermediate_states
=
merge_model_output
(
model_output_left
,
model_output_right
)
tbo_obj_v1
.
tbo_running
=
False
tbo_obj_v1
.
tbo_running
=
False
tbo_obj_v1
.
step_event
.
record
()
tbo_obj_v1
.
step_event
.
record
()
tbo_obj_v1
.
finish_thread
()
current_stream
.
wait_event
(
tbo_obj_v1
.
step_event
)
current_stream
.
wait_event
(
tbo_obj_v1
.
step_event
)
return
hidden_or_intermediate_states
return
hidden_or_intermediate_states
def
finalize_two_batch_overlap
():
global
tbo_obj_v1
if
tbo_obj_v1
is
not
None
:
try
:
tbo_obj_v1
.
shutdown
()
finally
:
tbo_obj_v1
=
None
vllm/v1/worker/gpu_model_runner.py
View file @
12291212
...
@@ -1374,8 +1374,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1374,8 +1374,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# If attention doesn't support CUDA Graphs for this batch, but we
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
if
envs
.
VLLM_ENABLE_TBO
and
scheduler_output
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
:
if
envs
.
VLLM_ENABLE_TBO
and
(
not
self
.
use_cuda_graph
or
skip_cuda_graphs
):
model_output
,
finished_sending
,
finished_recving
=
\
model_output
,
finished_sending
,
finished_recving
=
\
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
num_tokens_across_dp
,
input_ids
,
positions
,
num_tokens_across_dp
,
input_ids
,
positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment