Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
828aeaae
"vscode:/vscode.git/clone" did not exist on "aecdff1869c6ae2c923e9d6f164d20f3dd917bcf"
Commit
828aeaae
authored
May 19, 2025
by
lizhigong
Browse files
优化stream的初始化和warmup方式
parent
56ffc380
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
10 deletions
+16
-10
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+16
-10
No files found.
vllm/two_batch_overlap/two_batch_overlap.py
View file @
828aeaae
...
@@ -24,8 +24,13 @@ logger = init_logger(__name__)
...
@@ -24,8 +24,13 @@ logger = init_logger(__name__)
def
is_enable_tbo
():
def
is_enable_tbo
():
return
enable_tbo
return
enable_tbo
tbo_step_stream
=
None
all_reduce_stream
=
None
class
TwoBatchOverlap
():
class
TwoBatchOverlap
():
def
__init__
(
self
):
def
__init__
(
self
):
global
tbo_step_stream
global
all_reduce_stream
self
.
model_input_left_queue
=
queue
.
Queue
()
self
.
model_input_left_queue
=
queue
.
Queue
()
self
.
model_input_right_queue
=
queue
.
Queue
()
self
.
model_input_right_queue
=
queue
.
Queue
()
self
.
states_left_queue
=
queue
.
Queue
()
self
.
states_left_queue
=
queue
.
Queue
()
...
@@ -40,8 +45,9 @@ class TwoBatchOverlap():
...
@@ -40,8 +45,9 @@ class TwoBatchOverlap():
self
.
sem_right
=
threading
.
Semaphore
(
0
)
self
.
sem_right
=
threading
.
Semaphore
(
0
)
self
.
left_first
=
False
self
.
left_first
=
False
self
.
tbo_running
=
False
self
.
tbo_running
=
False
self
.
stream
=
torch
.
cuda
.
Stream
()
if
tbo_step_stream
==
None
:
self
.
step_stream
=
torch
.
cuda
.
Stream
()
tbo_step_stream
=
torch
.
cuda
.
Stream
()
all_reduce_stream
=
torch
.
cuda
.
Stream
()
self
.
step_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
step_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_left_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_left_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_right_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_right_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
...
@@ -80,7 +86,7 @@ class TwoBatchOverlap():
...
@@ -80,7 +86,7 @@ class TwoBatchOverlap():
self
.
right_tid
=
tid
self
.
right_tid
=
tid
logger
.
info
(
'tbo:new thread %d'
,
self
.
right_tid
)
logger
.
info
(
'tbo:new thread %d'
,
self
.
right_tid
)
init_tbo_forward_context
(
False
,
self
.
right_tid
)
init_tbo_forward_context
(
False
,
self
.
right_tid
)
with
torch
.
cuda
.
stream
(
self
.
step_stream
):
with
torch
.
cuda
.
stream
(
tbo_
step_stream
):
while
True
:
while
True
:
model_input
=
queue
.
get
()
model_input
=
queue
.
get
()
if
model_input
==
None
:
if
model_input
==
None
:
...
@@ -161,8 +167,8 @@ class TwoBatchOverlap():
...
@@ -161,8 +167,8 @@ class TwoBatchOverlap():
output
=
tensor_model_parallel_all_reduce
(
buf
)
output
=
tensor_model_parallel_all_reduce
(
buf
)
else
:
else
:
event_c2t
.
record
()
event_c2t
.
record
()
with
torch
.
cuda
.
stream
(
self
.
stream
):
with
torch
.
cuda
.
stream
(
all_reduce_
stream
):
self
.
stream
.
wait_event
(
event_c2t
)
all_reduce_
stream
.
wait_event
(
event_c2t
)
output
=
tensor_model_parallel_all_reduce
(
buf
)
output
=
tensor_model_parallel_all_reduce
(
buf
)
event_t2c
.
record
()
event_t2c
.
record
()
self
.
all_reduce_out
.
put
(
output
)
self
.
all_reduce_out
.
put
(
output
)
...
@@ -193,7 +199,7 @@ def tbo_all_reduce(obj):
...
@@ -193,7 +199,7 @@ def tbo_all_reduce(obj):
output
=
tbo_obj
.
all_reduce_out
.
get
()
output
=
tbo_obj
.
all_reduce_out
.
get
()
tbo_obj
.
tbo_thread_synchronize
(
tid
)
tbo_obj
.
tbo_thread_synchronize
(
tid
)
if
not
tbo_one_stream
:
if
not
tbo_one_stream
:
tbo_
obj
.
step_stream
.
wait_event
(
event_t2c
)
tbo_step_stream
.
wait_event
(
event_t2c
)
return
output
return
output
return
tensor_model_parallel_all_reduce
(
obj
)
return
tensor_model_parallel_all_reduce
(
obj
)
...
@@ -418,7 +424,6 @@ def tbo_model_executable(
...
@@ -418,7 +424,6 @@ def tbo_model_executable(
seqlen_agnostic_kwargs
,
seqlen_agnostic_kwargs
,
model_kwargs
,
model_kwargs
,
):
):
profile
.
ProfRangePush
(
'tbo_model_executable'
)
init_two_batch_overlap
()
init_two_batch_overlap
()
is_rocm_fa
=
isinstance
(
model_input
.
attn_metadata
,
ROCmFlashAttentionMetadata
)
is_rocm_fa
=
isinstance
(
model_input
.
attn_metadata
,
ROCmFlashAttentionMetadata
)
is_cuda_graph_decode
=
model_input
.
attn_metadata
.
use_cuda_graph
and
not
model_input
.
is_prompt
is_cuda_graph_decode
=
model_input
.
attn_metadata
.
use_cuda_graph
and
not
model_input
.
is_prompt
...
@@ -439,6 +444,7 @@ def tbo_model_executable(
...
@@ -439,6 +444,7 @@ def tbo_model_executable(
**
model_kwargs
,
**
model_kwargs
,
)
)
return
hidden_or_intermediate_states
return
hidden_or_intermediate_states
profile
.
ProfRangePush
(
'tbo_model_executable'
)
tbo_obj
.
tbo_running
=
True
tbo_obj
.
tbo_running
=
True
tbo_obj
.
left_first
=
True
tbo_obj
.
left_first
=
True
batch_size_left
=
int
(
batch_size
/
2
)
batch_size_left
=
int
(
batch_size
/
2
)
...
@@ -446,11 +452,11 @@ def tbo_model_executable(
...
@@ -446,11 +452,11 @@ def tbo_model_executable(
if
batch_size
%
2
==
1
:
if
batch_size
%
2
==
1
:
batch_size_right
+=
1
batch_size_right
+=
1
model_input_left
,
model_input_right
=
split_model_input
(
model_input
,
self_device
,
batch_size_left
,
batch_size_right
)
tbo_obj
.
step_event
.
record
()
tbo_obj
.
step_event
.
record
()
current_stream
=
torch
.
cuda
.
current_stream
()
current_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
tbo_obj
.
step_stream
):
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
tbo_obj
.
step_stream
.
wait_event
(
tbo_obj
.
step_event
)
tbo_step_stream
.
wait_event
(
tbo_obj
.
step_event
)
model_input_left
,
model_input_right
=
split_model_input
(
model_input
,
self_device
,
batch_size_left
,
batch_size_right
)
tbo_obj
.
set_model_input
(
model_input_left
,
tbo_obj
.
set_model_input
(
model_input_left
,
model_input_right
,
model_input_right
,
vllm_config
,
vllm_config
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment