Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aa906d98
Commit
aa906d98
authored
Jun 12, 2025
by
lizhigong
Browse files
add VLLM_TBO_DECODE_BS to support and setting the min bs on tbo decode
parent
59488cc9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
32 deletions
+22
-32
vllm/envs.py
vllm/envs.py
+6
-2
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+16
-30
No files found.
vllm/envs.py
View file @
aa906d98
...
...
@@ -126,8 +126,8 @@ if TYPE_CHECKING:
VLLM_HAS_CONTEXT_DEFAULT
:
bool
=
False
VLLM_FLASH_ATTN_BACKEND
:
bool
=
False
VLLM_ENABLE_TBO
:
bool
=
False
VLLM_TBO_REQ_DELAY_MS
:
int
=
0
VLLM_TBO_REQ_DELAY_MS
:
int
=
0
VLLM_TBO_DECODE_BS
:
int
=
0
VLLM_ZERO_OVERHEAD
:
bool
=
False
VLLM_ENABLE_MOE_FUSED_GATE
:
bool
=
False
...
...
@@ -823,6 +823,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TBO_REQ_DELAY_MS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_TBO_REQ_DELAY_MS"
,
"0"
)),
# set the minimum batch size to enable TBO in decode, if < 2 , disable TBO in decode.
"VLLM_TBO_DECODE_BS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_TBO_DECODE_BS"
,
"0"
)),
# Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ZERO_OVERHEAD"
,
"0"
))),
...
...
vllm/two_batch_overlap/two_batch_overlap.py
View file @
aa906d98
...
...
@@ -14,8 +14,6 @@ from vllm.logger import init_logger
from
vllm.profiler.prof
import
profile
from
vllm
import
envs
enable_tbo_decode
=
os
.
environ
.
get
(
'VLLM_TBO_DECODE'
)
==
'1'
tbo_one_stream
=
os
.
environ
.
get
(
'VLLM_TBO_ONE_STREAM'
)
==
'1'
logger
=
init_logger
(
__name__
)
...
...
@@ -31,8 +29,6 @@ class TwoBatchOverlap():
self
.
model_input_right_queue
=
queue
.
Queue
()
self
.
states_left_queue
=
queue
.
Queue
()
self
.
states_right_queue
=
queue
.
Queue
()
self
.
all_reduce_queue
=
queue
.
Queue
()
self
.
all_reduce_out
=
queue
.
Queue
()
self
.
left_thread
=
None
self
.
right_thread
=
None
self
.
left_tid
=
0
...
...
@@ -103,7 +99,6 @@ class TwoBatchOverlap():
self
.
sem_right
.
release
()
self
.
states_left_queue
.
put
(
hidden_or_intermediate_states
)
else
:
self
.
all_reduce_queue
.
put
(
None
)
self
.
states_right_queue
.
put
(
hidden_or_intermediate_states
)
profile
.
ProfRangePop
()
...
...
@@ -154,22 +149,6 @@ class TwoBatchOverlap():
states_right
=
self
.
states_right_queue
.
get
()
return
states_left
,
states_right
def
all_reduce
(
self
):
while
True
:
obj
=
self
.
all_reduce_queue
.
get
()
if
obj
==
None
:
break
buf
,
event_c2t
,
event_t2c
=
obj
if
tbo_one_stream
:
output
=
tensor_model_parallel_all_reduce
(
buf
)
else
:
event_c2t
.
record
()
with
torch
.
cuda
.
stream
(
all_reduce_stream
):
all_reduce_stream
.
wait_event
(
event_c2t
)
output
=
tensor_model_parallel_all_reduce
(
buf
)
event_t2c
.
record
()
self
.
all_reduce_out
.
put
(
output
)
tbo_obj
=
None
def
init_two_batch_overlap
():
...
...
@@ -186,11 +165,16 @@ def tbo_all_reduce(obj):
event_c2t
,
event_t2c
=
tbo_obj
.
event_left_c2t
,
tbo_obj
.
event_left_t2c
else
:
event_c2t
,
event_t2c
=
tbo_obj
.
event_right_c2t
,
tbo_obj
.
event_right_t2c
tbo_obj
.
all_reduce_queue
.
put
([
obj
,
event_c2t
,
event_t2c
])
output
=
tbo_obj
.
all_reduce_out
.
get
()
event_c2t
.
record
()
with
torch
.
cuda
.
stream
(
all_reduce_stream
):
all_reduce_stream
.
wait_event
(
event_c2t
)
output
=
tensor_model_parallel_all_reduce
(
obj
)
event_t2c
.
record
()
tbo_obj
.
tbo_thread_synchronize
(
tid
)
if
not
tbo_one_stream
:
tbo_step_stream
.
wait_event
(
event_t2c
)
else
:
output
=
tensor_model_parallel_all_reduce
(
obj
)
tbo_obj
.
tbo_thread_synchronize
(
tid
)
return
output
return
tensor_model_parallel_all_reduce
(
obj
)
...
...
@@ -218,12 +202,14 @@ def tbo_model_executable(
is_support
=
is_supported_attention_metadata
(
model_input
.
attn_metadata
)
if
not
is_support
:
logger
.
info
(
"tbo:not surpport yet "
,
type
(
model_input
.
attn_metadata
))
is_cuda_graph_decode
=
model_input
.
attn_metadata
.
use_cuda_graph
and
not
model_input
.
is_prompt
batch_size
=
len
(
model_input
.
attn_metadata
.
seq_lens
)
is_decode_tbo_invalid
=
not
model_input
.
is_prompt
and
(
envs
.
VLLM_TBO_DECODE_BS
<
2
or
batch_size
<
envs
.
VLLM_TBO_DECODE_BS
or
model_input
.
attn_metadata
.
use_cuda_graph
)
if
batch_size
==
1
or
\
(
not
model_input
.
is_prompt
and
not
enable_tbo_decode
)
or
\
not
is_support
or
\
is_cuda_graph_decode
:
is_decode_tbo_invalid
or
\
not
is_support
:
with
set_forward_context
(
model_input
.
attn_metadata
,
vllm_config
,
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
...
...
@@ -284,7 +270,7 @@ def tbo_model_executable(
seqlen_agnostic_kwargs
,
model_kwargs_left
,
model_kwargs_right
)
tbo_obj
.
all_reduce
()
states_left
,
states_right
=
tbo_obj
.
get_model_output
()
hidden_or_intermediate_states
=
merge_model_output
(
states_left
,
states_right
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment