Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
7ce36068
Unverified
Commit
7ce36068
authored
Oct 21, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 21, 2024
Browse files
Faster overlap mode scheduler (#1738)
parent
efb099cd
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
5 deletions
+25
-5
python/sglang/srt/managers/tp_worker_overlap_thread.py
python/sglang/srt/managers/tp_worker_overlap_thread.py
+25
-5
No files found.
python/sglang/srt/managers/tp_worker_overlap_thread.py
View file @
7ce36068
...
@@ -55,7 +55,7 @@ class TpModelWorkerClient:
...
@@ -55,7 +55,7 @@ class TpModelWorkerClient:
(
self
.
max_running_requests
*
5
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
(
self
.
max_running_requests
*
5
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
# Launch
a
thread
# Launch thread
s
self
.
input_queue
=
Queue
()
self
.
input_queue
=
Queue
()
self
.
output_queue
=
Queue
()
self
.
output_queue
=
Queue
()
self
.
forward_stream
=
torch
.
cuda
.
Stream
()
self
.
forward_stream
=
torch
.
cuda
.
Stream
()
...
@@ -64,6 +64,12 @@ class TpModelWorkerClient:
...
@@ -64,6 +64,12 @@ class TpModelWorkerClient:
)
)
self
.
forward_thread
.
start
()
self
.
forward_thread
.
start
()
self
.
copy_queue
=
Queue
()
self
.
copy_thread
=
threading
.
Thread
(
target
=
self
.
copy_thread_func
,
)
self
.
copy_thread
.
start
()
def get_worker_info(self):
    """Return worker info by delegating to the underlying TP worker."""
    worker = self.worker
    return worker.get_worker_info()
...
@@ -86,7 +92,10 @@ class TpModelWorkerClient:
...
@@ -86,7 +92,10 @@ class TpModelWorkerClient:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
forward_thread_func_
(
self
):
def
forward_thread_func_
(
self
):
while
True
:
while
True
:
self
.
has_inflight_batch
=
False
model_worker_batch
,
future_token_ids_ct
=
self
.
input_queue
.
get
()
model_worker_batch
,
future_token_ids_ct
=
self
.
input_queue
.
get
()
self
.
has_inflight_batch
=
True
self
.
launch_event
=
threading
.
Event
()
# Resolve future tokens in the input
# Resolve future tokens in the input
input_ids
=
model_worker_batch
.
input_ids
input_ids
=
model_worker_batch
.
input_ids
...
@@ -100,6 +109,7 @@ class TpModelWorkerClient:
...
@@ -100,6 +109,7 @@ class TpModelWorkerClient:
logits_output
,
next_token_ids
=
self
.
worker
.
forward_batch_generation
(
logits_output
,
next_token_ids
=
self
.
worker
.
forward_batch_generation
(
model_worker_batch
model_worker_batch
)
)
self
.
launch_event
.
set
()
# Update the future token ids map
# Update the future token ids map
bs
=
len
(
model_worker_batch
.
seq_lens
)
bs
=
len
(
model_worker_batch
.
seq_lens
)
...
@@ -113,13 +123,23 @@ class TpModelWorkerClient:
...
@@ -113,13 +123,23 @@ class TpModelWorkerClient:
torch
.
int32
torch
.
int32
)
)
# Set the result
next_token_ids
=
next_token_ids
.
to
(
"cpu"
,
non_blocking
=
True
)
next_token_ids
=
next_token_ids
.
tolist
()
copy_event
=
torch
.
cuda
.
Event
(
blocking
=
True
)
assert
logits_output
.
next_token_logprobs
is
None
,
"Not supported"
copy_event
.
record
()
self
.
output_queue
.
put
((
None
,
next_token_ids
))
self
.
copy_queue
.
put
((
copy_event
,
next_token_ids
))
def copy_thread_func(self):
    """Wait for device-to-host copies to finish, then publish token ids.

    Runs forever: for each (copy_event, next_token_ids) pair taken from
    copy_queue, polls the CUDA event until it completes, then pushes the
    token ids (converted to a Python list) onto output_queue.
    """
    while True:
        event, token_ids = self.copy_queue.get()
        # Poll with a tiny sleep rather than a blocking synchronize so
        # other Python threads keep making progress while we wait.
        while not event.query():
            time.sleep(1e-5)
        self.output_queue.put((None, token_ids.tolist()))
def resulve_batch_result(self, bid: int):
    """Block until a batch result is available and return it.

    Returns a (logits_output, next_token_ids) pair taken from
    output_queue. If a batch is still in flight, waits for the forward
    thread to signal that the batch has actually been launched before
    handing the result back.

    NOTE(review): the name looks like a typo for "resolve_batch_result",
    but it is the public interface, so it is preserved as-is.
    """
    result = self.output_queue.get()
    if self.has_inflight_batch:
        # Wait until the in-flight batch has been launched on the
        # forward thread.
        self.launch_event.wait()
    logits_output, next_token_ids = result
    return logits_output, next_token_ids
def
forward_batch_generation
(
self
,
model_worker_batch
:
ModelWorkerBatch
):
def
forward_batch_generation
(
self
,
model_worker_batch
:
ModelWorkerBatch
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment