Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
56ffc380
Commit
56ffc380
authored
May 19, 2025
by
lizhigong
Browse files
调试tbo正确性
parent
2a935929
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
64 additions
and
57 deletions
+64
-57
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+64
-57
No files found.
vllm/two_batch_overlap/two_batch_overlap.py
View file @
56ffc380
...
@@ -41,6 +41,8 @@ class TwoBatchOverlap():
...
@@ -41,6 +41,8 @@ class TwoBatchOverlap():
self
.
left_first
=
False
self
.
left_first
=
False
self
.
tbo_running
=
False
self
.
tbo_running
=
False
self
.
stream
=
torch
.
cuda
.
Stream
()
self
.
stream
=
torch
.
cuda
.
Stream
()
self
.
step_stream
=
torch
.
cuda
.
Stream
()
self
.
step_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_left_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_left_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_right_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_right_c2t
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_left_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_left_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
...
@@ -68,46 +70,44 @@ class TwoBatchOverlap():
...
@@ -68,46 +70,44 @@ class TwoBatchOverlap():
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
thread_two_batch_overlap
(
self
,
queue
):
def
thread_two_batch_overlap
(
self
,
queue
):
is_left_thread
=
False
is_left_thread
=
False
tid
=
threading
.
get_ident
()
if
queue
==
self
.
model_input_left_queue
:
if
queue
==
self
.
model_input_left_queue
:
self
.
left_tid
=
t
hreading
.
get_ident
()
self
.
left_tid
=
t
id
is_left_thread
=
True
is_left_thread
=
True
logger
.
info
(
'tbo:new thread %d'
,
self
.
left_tid
)
logger
.
info
(
'tbo:new thread %d'
,
self
.
left_tid
)
init_tbo_forward_context
(
True
,
self
.
left_tid
)
init_tbo_forward_context
(
True
,
self
.
left_tid
)
else
:
else
:
self
.
right_tid
=
t
hreading
.
get_ident
()
self
.
right_tid
=
t
id
logger
.
info
(
'tbo:new thread %d'
,
self
.
right_tid
)
logger
.
info
(
'tbo:new thread %d'
,
self
.
right_tid
)
init_tbo_forward_context
(
False
,
self
.
right_tid
)
init_tbo_forward_context
(
False
,
self
.
right_tid
)
while
True
:
with
torch
.
cuda
.
stream
(
self
.
step_stream
):
model_input
=
queue
.
get
()
while
True
:
if
model_input
==
None
:
model_input
=
queue
.
get
()
break
if
model_input
==
None
:
profile
.
ProfRangePush
(
'start'
)
break
self
.
tbo_thread_synchronize
(
False
)
profile
.
ProfRangePush
(
'start'
)
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
tbo_thread_synchronize
(
tid
)
self
.
vllm_config
,
self
.
virtual_engine
):
with
set_forward_context
(
model_input
.
attn_metadata
,
hidden_or_intermediate_states
=
self
.
model_executable
(
self
.
vllm_config
,
self
.
virtual_engine
):
input_ids
=
model_input
.
input_tokens
,
hidden_or_intermediate_states
=
self
.
model_executable
(
positions
=
model_input
.
input_positions
,
input_ids
=
model_input
.
input_tokens
,
intermediate_tensors
=
self
.
intermediate_tensors
,
positions
=
model_input
.
input_positions
,
**
MultiModalKwargs
.
as_kwargs
(
self
.
multi_modal_kwargs
,
intermediate_tensors
=
self
.
intermediate_tensors
,
device
=
self
.
self_device
),
**
MultiModalKwargs
.
as_kwargs
(
self
.
multi_modal_kwargs
,
**
self
.
seqlen_agnostic_kwargs
,
device
=
self
.
self_device
),
**
self
.
model_kwargs
,
**
self
.
seqlen_agnostic_kwargs
,
)
**
self
.
model_kwargs
,
profile
.
ProfRangePush
(
'end'
)
)
if
is_left_thread
:
if
is_left_thread
:
self
.
sem_right
.
release
()
self
.
sem_right
.
release
()
self
.
states_left_queue
.
put
(
hidden_or_intermediate_states
)
self
.
states_left_queue
.
put
(
hidden_or_intermediate_states
)
else
:
else
:
self
.
all_reduce_queue
.
put
(
None
)
self
.
all_reduce_queue
.
put
(
None
)
self
.
states_right_queue
.
put
(
hidden_or_intermediate_states
)
self
.
states_right_queue
.
put
(
hidden_or_intermediate_states
)
profile
.
ProfRangePop
()
def
tbo_thread_synchronize
(
self
,
recode_flag
=
True
):
def
tbo_thread_synchronize
(
self
,
tid
):
tid
=
threading
.
get_ident
()
if
tid
==
self
.
left_tid
:
if
tid
==
self
.
left_tid
:
if
recode_flag
and
not
tbo_one_stream
:
print
(
'###left_c2t_recorded'
)
self
.
event_left_c2t
.
record
()
if
not
self
.
left_first
:
if
not
self
.
left_first
:
self
.
sem_right
.
release
()
self
.
sem_right
.
release
()
profile
.
ProfRangePop
()
profile
.
ProfRangePop
()
...
@@ -116,9 +116,6 @@ class TwoBatchOverlap():
...
@@ -116,9 +116,6 @@ class TwoBatchOverlap():
self
.
left_first
=
False
self
.
left_first
=
False
return
self
.
event_left_c2t
,
self
.
event_left_t2c
return
self
.
event_left_c2t
,
self
.
event_left_t2c
else
:
else
:
if
recode_flag
and
not
tbo_one_stream
:
print
(
'###right_c2t_recorded'
)
self
.
event_right_c2t
.
record
()
self
.
sem_left
.
release
()
self
.
sem_left
.
release
()
profile
.
ProfRangePop
()
profile
.
ProfRangePop
()
self
.
sem_right
.
acquire
()
self
.
sem_right
.
acquire
()
...
@@ -160,17 +157,14 @@ class TwoBatchOverlap():
...
@@ -160,17 +157,14 @@ class TwoBatchOverlap():
if
obj
==
None
:
if
obj
==
None
:
break
break
buf
,
event_c2t
,
event_t2c
=
obj
buf
,
event_c2t
,
event_t2c
=
obj
#print('###buf', buf[0,0:5])
if
tbo_one_stream
:
if
tbo_one_stream
:
output
=
tensor_model_parallel_all_reduce
(
buf
)
output
=
tensor_model_parallel_all_reduce
(
buf
)
else
:
else
:
event_c2t
.
record
()
with
torch
.
cuda
.
stream
(
self
.
stream
):
with
torch
.
cuda
.
stream
(
self
.
stream
):
print
(
'###stream.wait_event event_c2t before all_reduce'
)
self
.
stream
.
wait_event
(
event_c2t
)
self
.
stream
.
wait_event
(
event_c2t
)
output
=
tensor_model_parallel_all_reduce
(
buf
)
output
=
tensor_model_parallel_all_reduce
(
buf
)
print
(
'###event_t2c recorded'
)
event_t2c
.
record
()
event_t2c
.
record
()
#print('###print', output[0,0:5])
self
.
all_reduce_out
.
put
(
output
)
self
.
all_reduce_out
.
put
(
output
)
tbo_obj
=
None
tbo_obj
=
None
...
@@ -189,13 +183,17 @@ def finish_two_batch_overlap():
...
@@ -189,13 +183,17 @@ def finish_two_batch_overlap():
def
tbo_all_reduce
(
obj
):
def
tbo_all_reduce
(
obj
):
if
enable_tbo
and
tbo_obj
!=
None
and
tbo_obj
.
tbo_running
:
if
enable_tbo
and
tbo_obj
!=
None
and
tbo_obj
.
tbo_running
:
event_c2t
,
event_t2c
=
tbo_obj
.
tbo_thread_synchronize
()
tid
=
threading
.
get_ident
()
if
not
tbo_one_stream
:
if
tid
==
tbo_obj
.
left_tid
:
event_c2t
,
event_t2c
=
tbo_obj
.
event_left_c2t
,
tbo_obj
.
event_left_t2c
else
:
event_c2t
,
event_t2c
=
tbo_obj
.
event_right_c2t
,
tbo_obj
.
event_right_t2c
tbo_obj
.
all_reduce_queue
.
put
([
obj
,
event_c2t
,
event_t2c
])
tbo_obj
.
all_reduce_queue
.
put
([
obj
,
event_c2t
,
event_t2c
])
output
=
tbo_obj
.
all_reduce_out
.
get
()
output
=
tbo_obj
.
all_reduce_out
.
get
()
tbo_obj
.
tbo_thread_synchronize
(
tid
)
if
not
tbo_one_stream
:
if
not
tbo_one_stream
:
current_stream
=
torch
.
cuda
.
current_stream
()
tbo_obj
.
step_stream
.
wait_event
(
event_t2c
)
print
(
'###current_stream wait event event_t2c'
)
current_stream
.
wait_event
(
event_t2c
)
return
output
return
output
return
tensor_model_parallel_all_reduce
(
obj
)
return
tensor_model_parallel_all_reduce
(
obj
)
...
@@ -420,6 +418,7 @@ def tbo_model_executable(
...
@@ -420,6 +418,7 @@ def tbo_model_executable(
seqlen_agnostic_kwargs
,
seqlen_agnostic_kwargs
,
model_kwargs
,
model_kwargs
,
):
):
profile
.
ProfRangePush
(
'tbo_model_executable'
)
init_two_batch_overlap
()
init_two_batch_overlap
()
is_rocm_fa
=
isinstance
(
model_input
.
attn_metadata
,
ROCmFlashAttentionMetadata
)
is_rocm_fa
=
isinstance
(
model_input
.
attn_metadata
,
ROCmFlashAttentionMetadata
)
is_cuda_graph_decode
=
model_input
.
attn_metadata
.
use_cuda_graph
and
not
model_input
.
is_prompt
is_cuda_graph_decode
=
model_input
.
attn_metadata
.
use_cuda_graph
and
not
model_input
.
is_prompt
...
@@ -446,20 +445,28 @@ def tbo_model_executable(
...
@@ -446,20 +445,28 @@ def tbo_model_executable(
batch_size_right
=
batch_size_left
batch_size_right
=
batch_size_left
if
batch_size
%
2
==
1
:
if
batch_size
%
2
==
1
:
batch_size_right
+=
1
batch_size_right
+=
1
model_input_left
,
model_input_right
=
split_model_input
(
model_input
,
self_device
,
batch_size_left
,
batch_size_right
)
tbo_obj
.
set_model_input
(
model_input_left
,
tbo_obj
.
step_event
.
record
()
model_input_right
,
current_stream
=
torch
.
cuda
.
current_stream
()
vllm_config
,
with
torch
.
cuda
.
stream
(
tbo_obj
.
step_stream
):
virtual_engine
,
tbo_obj
.
step_stream
.
wait_event
(
tbo_obj
.
step_event
)
model_executable
,
model_input_left
,
model_input_right
=
split_model_input
(
model_input
,
self_device
,
batch_size_left
,
batch_size_right
)
intermediate_tensors
,
tbo_obj
.
set_model_input
(
model_input_left
,
multi_modal_kwargs
,
model_input_right
,
self_device
,
vllm_config
,
seqlen_agnostic_kwargs
,
virtual_engine
,
model_kwargs
)
model_executable
,
tbo_obj
.
all_reduce
()
intermediate_tensors
,
states_left
,
states_right
=
tbo_obj
.
get_model_output
()
multi_modal_kwargs
,
self_device
,
seqlen_agnostic_kwargs
,
model_kwargs
)
tbo_obj
.
all_reduce
()
states_left
,
states_right
=
tbo_obj
.
get_model_output
()
hidden_or_intermediate_states
=
merge_model_output
(
states_left
,
states_right
)
hidden_or_intermediate_states
=
merge_model_output
(
states_left
,
states_right
)
tbo_obj
.
tbo_running
=
False
tbo_obj
.
tbo_running
=
False
tbo_obj
.
step_event
.
record
()
current_stream
.
wait_event
(
tbo_obj
.
step_event
)
profile
.
ProfRangePop
()
return
hidden_or_intermediate_states
return
hidden_or_intermediate_states
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment