Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
20e75ed6
Commit
20e75ed6
authored
Aug 02, 2025
by
lizhigong
Committed by
maxiao1@sugon.com
Aug 04, 2025
Browse files
add tbo on v1 engine
parent
eba84521
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
892 additions
and
2 deletions
+892
-2
vllm/envs.py
vllm/envs.py
+5
-0
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+3
-0
vllm/two_batch_overlap/v1/gpu_model_runner.py
vllm/two_batch_overlap/v1/gpu_model_runner.py
+638
-0
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
+239
-0
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+7
-2
No files found.
vllm/envs.py
View file @
20e75ed6
...
...
@@ -159,6 +159,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_TBO
:
bool
=
False
VLLM_TBO_REQ_DELAY_MS
:
int
=
0
VLLM_TBO_DECODE_BS
:
int
=
0
VLLM_TBO_MIN_TOKENS
:
int
=
200
VLLM_ZERO_OVERHEAD
:
bool
=
False
VLLM_ENABLE_MOE_FUSED_GATE
:
bool
=
False
VLLM_USE_FLASH_ATTN_PA
:
bool
=
False
...
...
@@ -1069,6 +1070,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TBO_DECODE_BS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_TBO_DECODE_BS"
,
"0"
)),
# Minimum number of tokens each mini-batch must contain for TBO to be enabled on the v1 engine (default: 200).
"VLLM_TBO_MIN_TOKENS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_TBO_MIN_TOKENS"
,
"200"
)),
# Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ZERO_OVERHEAD"
,
"0"
))),
...
...
vllm/two_batch_overlap/two_batch_overlap.py
View file @
20e75ed6
...
...
@@ -16,6 +16,7 @@ from vllm.logger import init_logger
from
vllm.profiler.prof
import
profile
from
vllm
import
envs
from
vllm.utils
import
weak_ref_tensor
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
is_enable_tbo_v1
,
tbo_all_reduce_v1
tbo_one_stream
=
os
.
environ
.
get
(
'VLLM_TBO_ONE_STREAM'
)
==
'1'
...
...
@@ -214,6 +215,8 @@ def init_two_batch_overlap():
tbo_obj
.
init_tbo_thread
()
def
tbo_all_reduce
(
obj
):
if
is_enable_tbo_v1
():
return
tbo_all_reduce_v1
(
obj
)
if
envs
.
VLLM_ENABLE_TBO
and
tbo_obj
!=
None
and
tbo_obj
.
tbo_running
:
tid
=
threading
.
get_ident
()
if
not
tbo_one_stream
:
...
...
vllm/two_batch_overlap/v1/gpu_model_runner.py
0 → 100644
View file @
20e75ed6
This diff is collapsed.
Click to expand it.
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
0 → 100644
View file @
20e75ed6
import
os
import
queue
import
threading
import
torch
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.forward_context
import
init_tbo_forward_context
from
vllm.logger
import
init_logger
from
vllm.profiler.prof
import
profile
from
vllm
import
envs
# Module-level logger for the v1 two-batch-overlap (TBO) helpers.
logger = init_logger(__name__)

# CUDA streams shared by the TBO code in this module. Both start as None and
# are created the first time a TwoBatchOverlap instance is constructed.
# tbo_step_stream is the stream the worker threads run the model on;
# all_reduce_stream is the stream tbo_all_reduce_v1 issues all-reduces on.
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap:
    """Run one model step as two mini-batches on two cooperating threads.

    A "left" and a "right" worker thread each execute the model on one
    mini-batch. The threads never run model code at the same time: they hand
    control back and forth through ``sem_left`` / ``sem_right`` inside
    ``tbo_thread_synchronize``. The worker threads are single-shot: each one
    waits for a start token, runs the model once, publishes its output and
    exits, so ``init_tbo_thread`` / ``finish_thread`` bracket every step.
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream

        # Start-signal queues: set_model_input() puts one token per worker.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Result queues: each worker puts its model output here.
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()

        self.left_thread = None
        self.right_thread = None
        # Thread identifiers, filled in by the workers themselves on start-up.
        self.left_tid = 0
        self.right_tid = 0

        # Hand-off semaphores; both start at 0 so a worker blocks until the
        # other side (or the end of the left worker) releases it.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # True only until the left worker's first synchronize of a step.
        self.left_first = False
        self.tbo_running = False
        # Not read anywhere in the visible code.
        self.tbo_in_capture = False

        # The two CUDA streams are module-level and created only once.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()

        self.step_event = torch.cuda.Event(enable_timing=False)
        # c2t / t2c event pairs are recorded and waited on around the
        # all-reduce issued on all_reduce_stream (see tbo_all_reduce_v1).
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain(pending):
        """Discard every item currently buffered in *pending* without blocking."""
        while True:
            try:
                pending.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Start a fresh left/right worker thread pair for one step."""
        # Queue.empty() only reports whether the queue is empty; it does not
        # clear it. Drain explicitly so a start token left over from an
        # earlier step cannot trigger the new workers prematurely.
        self._drain(self.model_input_left_queue)
        self._drain(self.model_input_right_queue)

        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.left_thread.start()
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.right_thread.start()
        logger.info('tbo:two batch overlap start')

    def finish_thread(self):
        """Join both worker threads and forget them.

        Threads that were never started (or were already joined) are skipped.
        """
        if self.left_thread is not None:
            self.left_thread.join()
            self.left_thread = None
        if self.right_thread is not None:
            self.right_thread.join()
            self.right_thread = None

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker body: wait for the start token, run the model once, publish.

        Args:
            queue: The start-signal queue of this worker. It is compared with
                ``model_input_left_queue`` to decide whether this is the left
                or the right worker. (The name shadows the ``queue`` module
                inside this method only.)
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = tid
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)

        with torch.cuda.stream(tbo_step_stream):
            # Block until set_model_input() has stored this step's inputs.
            queue.get()
            profile.ProfRangePush('start')
            # First hand-off of the step: decides which worker runs first.
            self.tbo_thread_synchronize(tid)

            if is_left_thread:
                attn_metadata = self.attn_metadata_left
                num_input_tokens = self.num_input_tokens_left
                input_ids = self.input_ids_left
                positions = self.positions_left
            else:
                attn_metadata = self.attn_metadata_right
                num_input_tokens = self.num_input_tokens_right
                input_ids = self.input_ids_right
                positions = self.positions_right

            # Run the decoder.
            # Use persistent buffers for CUDA graphs.
            # NOTE(review): if the model call raises, nothing is put on the
            # states queue and get_model_output() would block - confirm that
            # failures are handled by the caller.
            with set_forward_context(
                    attn_metadata,
                    self.model_runner.vllm_config,
                    num_tokens=num_input_tokens,
                    num_tokens_across_dp=self.num_tokens_across_dp):
                model_output = self.model_runner.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=self.intermediate_tensors,
                    inputs_embeds=self.inputs_embeds,
                )

            if is_left_thread:
                # The left worker is done: let the right worker proceed.
                self.sem_right.release()
                self.states_left_queue.put(model_output)
            else:
                self.states_right_queue.put(model_output)
            profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Yield to the other worker and block until it yields back.

        Args:
            tid: Identifier of the calling thread; a value equal to
                ``left_tid`` selects the left branch, anything else the right.

        Returns:
            The ``(c2t, t2c)`` event pair belonging to the calling side.
        """
        if tid == self.left_tid:
            # On the first call of a step (left_first is True) the right
            # worker is not released, so the left worker keeps its turn order.
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            return self.event_left_c2t, self.event_left_t2c

        self.sem_left.release()
        profile.ProfRangePop()
        self.sem_right.acquire()
        profile.ProfRangePush('right')
        return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self, model_runner, attn_metadata_left,
                        attn_metadata_right, num_input_tokens_left,
                        num_input_tokens_right, input_ids_left,
                        input_ids_right, positions_left, positions_right,
                        num_tokens_across_dp, intermediate_tensors,
                        inputs_embeds):
        """Store the inputs of one step and wake both worker threads.

        The ``*_left`` / ``*_right`` values are read by the matching worker;
        ``model_runner``, ``num_tokens_across_dp``, ``intermediate_tensors``
        and ``inputs_embeds`` are shared by both.
        """
        self.model_runner = model_runner
        self.attn_metadata_left = attn_metadata_left
        self.attn_metadata_right = attn_metadata_right
        self.num_input_tokens_left = num_input_tokens_left
        self.num_input_tokens_right = num_input_tokens_right
        self.input_ids_left = input_ids_left
        self.input_ids_right = input_ids_right
        self.positions_left = positions_left
        self.positions_right = positions_right
        self.num_tokens_across_dp = num_tokens_across_dp
        self.intermediate_tensors = intermediate_tensors
        self.inputs_embeds = inputs_embeds
        # The queued value is only a wake-up token; the workers read the
        # attributes assigned above.
        self.model_input_left_queue.put(None)
        self.model_input_right_queue.put(None)

    def get_model_output(self):
        """Block until both workers have published, then return their outputs.

        Returns:
            A ``(left_output, right_output)`` tuple.
        """
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right
# Process-wide TwoBatchOverlap singleton. It stays None until
# init_two_batch_overlap() creates it.
tbo_obj_v1 = None


def is_enable_tbo_v1():
    """Return True once the v1 TBO singleton has been created."""
    # Identity check: None is a singleton, and reading a module global needs
    # no `global` declaration.
    return tbo_obj_v1 is not None
def init_two_batch_overlap():
    """Create the TBO singleton on first use and start its worker threads."""
    global tbo_obj_v1
    if tbo_obj_v1 is None:
        tbo_obj_v1 = TwoBatchOverlap()
    # The worker threads run the model once and exit, and are joined by
    # finish_thread(), so a fresh pair is started on every call.
    tbo_obj_v1.init_tbo_thread()
def tbo_all_reduce_v1(obj):
    """All-reduce *obj*, overlapping it with the other mini-batch when TBO runs.

    When TBO is disabled, not initialised, or no TBO step is in flight, this
    is a plain ``tensor_model_parallel_all_reduce(obj)``. Otherwise the
    all-reduce is issued on ``all_reduce_stream`` and the calling worker
    thread yields to the other worker while it is in flight.

    Args:
        obj: Value forwarded unchanged to ``tensor_model_parallel_all_reduce``.

    Returns:
        The result of ``tensor_model_parallel_all_reduce(obj)``.
    """
    tbo_active = (envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None
                  and tbo_obj_v1.tbo_running)
    if not tbo_active:
        return tensor_model_parallel_all_reduce(obj)

    tid = threading.get_ident()
    # Any thread other than the left worker is treated as the right worker.
    if tid == tbo_obj_v1.left_tid:
        event_c2t = tbo_obj_v1.event_left_c2t
        event_t2c = tbo_obj_v1.event_left_t2c
    else:
        event_c2t = tbo_obj_v1.event_right_c2t
        event_t2c = tbo_obj_v1.event_right_t2c

    # Mark the point on the caller's current stream that the all-reduce must
    # wait for.
    event_c2t.record()
    with torch.cuda.stream(all_reduce_stream):
        all_reduce_stream.wait_event(event_c2t)
        output = tensor_model_parallel_all_reduce(obj)
        # Recorded while all_reduce_stream is current, so the event completes
        # only after the all-reduce has finished.
        event_t2c.record()

    # Hand control to the other worker; this blocks until it yields back.
    tbo_obj_v1.tbo_thread_synchronize(tid)
    # Work queued on the step stream after this point sees the reduced data.
    tbo_step_stream.wait_event(event_t2c)
    return output
def merge_model_output(states_left, states_right):
    """Join the outputs of the two mini-batches along dimension 0.

    Args:
        states_left: Output of the left mini-batch; either an
            ``IntermediateTensors`` or a value accepted by ``torch.concat``.
        states_right: Output of the right mini-batch, of the same kind.

    Returns:
        An ``IntermediateTensors`` whose entries are the per-key
        concatenations when ``states_left`` is an ``IntermediateTensors``,
        otherwise the concatenation of the two inputs.
    """
    if not isinstance(states_left, IntermediateTensors):
        return torch.concat([states_left, states_right], dim=0)

    # The keys of the left side drive the merge; the right side is indexed
    # with the same keys.
    merged = {
        name: torch.concat(
            [states_left.tensors[name], states_right.tensors[name]], dim=0)
        for name in states_left.tensors
    }
    return IntermediateTensors(merged)
def tbo_model_executable_v1(model_runner, attn_metadata_left,
                            attn_metadata_right, num_input_tokens_left,
                            num_input_tokens_right, num_tokens_across_dp,
                            input_ids, positions, intermediate_tensors,
                            inputs_embeds):
    """Run one model step as two overlapped mini-batches and merge the result.

    ``input_ids`` and ``positions`` are split along dim 0 into a left part of
    ``num_input_tokens_left`` entries and a right part of
    ``num_input_tokens_right`` entries. Each part is executed by its own
    worker thread of the TBO singleton, and the two outputs are concatenated
    by ``merge_model_output``.

    Returns:
        The merged model output of both mini-batches.
    """
    # (Re)creates the singleton if needed and starts this step's workers.
    init_two_batch_overlap()
    tbo_obj_v1.tbo_running = True
    # Makes the left worker skip releasing the right one on its first
    # synchronize of this step.
    tbo_obj_v1.left_first = True
    # Mark the current position of the caller's stream so the step stream can
    # wait for work queued before this call.
    tbo_obj_v1.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj_v1.step_event)
        tokens_split = [num_input_tokens_left, num_input_tokens_right]
        input_ids_left, input_ids_right = torch.split(input_ids, tokens_split, dim=0)
        positions_left, positions_right = torch.split(positions, tokens_split, dim=0)
        # NOTE(review): intermediate_tensors and inputs_embeds are forwarded
        # unsplit to both workers - confirm callers only pass values that are
        # valid for both mini-batches.
        tbo_obj_v1.set_model_input(model_runner, attn_metadata_left, attn_metadata_right, num_input_tokens_left, num_input_tokens_right, input_ids_left, input_ids_right, positions_left, positions_right, num_tokens_across_dp, intermediate_tensors, inputs_embeds)
        # Blocks until both workers have published their outputs.
        model_output_left, model_output_right = tbo_obj_v1.get_model_output()
        hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
        tbo_obj_v1.tbo_running = False
        # Re-recorded while the step stream is current, so it marks the end
        # of the work queued on that stream.
        tbo_obj_v1.step_event.record()
    tbo_obj_v1.finish_thread()
    # The caller's stream must not continue before the step stream is done.
    current_stream.wait_event(tbo_obj_v1.step_event)
    return hidden_or_intermediate_states
\ No newline at end of file
vllm/v1/worker/gpu_worker.py
View file @
20e75ed6
...
...
@@ -22,6 +22,7 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.v1.gpu_model_runner
import
TBO_GPUModelRunner
from
vllm.utils
import
GiB_bytes
,
MemorySnapshot
,
memory_profiling
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
ModelRunnerOutput
...
...
@@ -162,6 +163,10 @@ class Worker(WorkerBase):
set_random_seed
(
self
.
model_config
.
seed
)
# Construct the model runner
if
envs
.
VLLM_ENABLE_TBO
:
self
.
model_runner
:
TBO_GPUModelRunner
=
TBO_GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
else
:
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment