sglang · Commits · 05bea688

Unverified commit 05bea688, authored Sep 07, 2024 by Liangsheng Yin, committed by GitHub on Sep 07, 2024.

Fix some online scheduling delay (#1345)

parent ab4a83b2

Showing 2 changed files with 52 additions and 36 deletions (+52 / -36):

  python/sglang/srt/managers/policy_scheduler.py   +47 / -30
  python/sglang/srt/managers/tp_worker.py           +5 / -6
python/sglang/srt/managers/policy_scheduler.py  (view file @ 05bea688)
@@ -119,6 +119,7 @@ class PrefillAdder:
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_total_tokens_ = self.rem_total_tokens
         self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
@@ -153,11 +154,18 @@ class PrefillAdder:
                 for r in running_batch.reqs
             ]
         )
+        self.rem_total_tokens_ -= sum(
+            [
+                r.sampling_params.max_new_tokens - len(r.output_ids)
+                for r in running_batch.reqs
+            ]
+        )
 
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
+        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
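In isolation, the bookkeeping in this hunk amounts to keeping a second, unclipped token counter in lockstep with the main budget. A minimal sketch of that idea, assuming simplified request objects with plain max_new_tokens and output_ids attributes (the TokenBudget class below is illustrative only, not the actual PrefillAdder):

class TokenBudget:
    """Illustrative stand-in for the two PrefillAdder counters in the diff."""

    def __init__(self, rem_total_tokens: int):
        self.rem_total_tokens = rem_total_tokens   # budget used for admission
        self.rem_total_tokens_ = rem_total_tokens  # shadow counter, kept in lockstep

    def remove_running_tokens(self, running_reqs) -> None:
        # Reserve the tokens that running requests may still generate.
        # (The matching, ratio-scaled update of rem_total_tokens is elided here.)
        self.rem_total_tokens_ -= sum(
            r.max_new_tokens - len(r.output_ids) for r in running_reqs
        )

    def prefill_one_req(self, extend_input_len: int, max_new_tokens: int) -> None:
        # Charge both counters when a new prefill request is admitted.
        self.rem_total_tokens -= extend_input_len + max_new_tokens
        self.rem_total_tokens_ -= extend_input_len + max_new_tokens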
@@ -231,43 +239,52 @@ class PrefillAdder:
             return None
 
-        if self.req_states is None:
-            self.req_states = []
-            if self.running_batch is not None:
-                for r in self.running_batch.reqs:
-                    state = get_req_state(r)
-                    if state is not None:
-                        self.req_states.append(state)
-            for r in self.can_run_list:
-                state = get_req_state(r)
-                if state is not None:
-                    self.req_states.append(state)
-            state = get_req_state(req)
-            if state is not None:
-                self.req_states.append(state)
-            self.req_states.sort(key=lambda x: x[0])
-        else:
-            state = get_req_state(req)
-            if state is not None:
-                for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                    if tokens_left >= state[0]:
-                        self.req_states.insert(i, state)
-                        break
-                else:
-                    self.req_states.append(state)
-
-        tokens_freed = 0
-        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-            decode_steps = (
-                self.req_states[i + 1][0]
-                if i + 1 < len(self.req_states)
-                else tokens_left
-            )
-            bs = len(self.req_states) - i
-            if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
-                return False
-            tokens_freed += tokens_occupied
+        # Quick Check
+        can_run = False
+        if (
+            req.extend_input_len + req.sampling_params.max_new_tokens
+            <= self.rem_total_tokens
+        ):
+            can_run = True
+
+        if not can_run:
+            if self.req_states is None:
+                self.req_states = []
+                if self.running_batch is not None:
+                    for r in self.running_batch.reqs:
+                        state = get_req_state(r)
+                        if state is not None:
+                            self.req_states.append(state)
+                for r in self.can_run_list:
+                    state = get_req_state(r)
+                    if state is not None:
+                        self.req_states.append(state)
+                state = get_req_state(req)
+                if state is not None:
+                    self.req_states.append(state)
+                self.req_states.sort(key=lambda x: x[0])
+            else:
+                state = get_req_state(req)
+                if state is not None:
+                    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                        if tokens_left >= state[0]:
+                            self.req_states.insert(i, state)
+                            break
+                    else:
+                        self.req_states.append(state)
+
+            tokens_freed = 0
+            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                decode_steps = (
+                    self.req_states[i + 1][0]
+                    if i + 1 < len(self.req_states)
+                    else tokens_left
+                )
+                bs = len(self.req_states) - i
+                if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
+                    return False
+                tokens_freed += tokens_occupied
 
         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
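Taken together, this hunk puts a fast path in front of the per-request budget simulation: when the incoming request plus its maximum new tokens fits into the remaining budget, admission is decided immediately and req_states is never built. A rough self-contained sketch of that control flow, with plain functions and (tokens_left, tokens_occupied) tuples standing in for the real PrefillAdder state (names other than those visible in the diff are hypothetical):

def budget_allows(total_tokens: int, req_states: list) -> bool:
    """Fallback check: req_states holds (tokens_left, tokens_occupied) pairs,
    sorted by tokens_left ascending, one per running or already-admitted request."""
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        # Horizon: tokens_left of the next request (or this one's, if it is last).
        decode_steps = req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        bs = len(req_states) - i  # requests that are still running at this point
        # Reject if the budget would be exhausted before tokens_occupied is freed.
        if total_tokens + tokens_freed - decode_steps * bs <= 0:
            return False
        tokens_freed += tokens_occupied
    return True


def can_admit(req_tokens: int, rem_total_tokens: int,
              total_tokens: int, req_states: list) -> bool:
    # Quick check: admit immediately when the request obviously fits.
    if req_tokens <= rem_total_tokens:
        return True
    # Slow path: only now pay for the per-request simulation.
    return budget_allows(total_tokens, req_states)

The quick check is what removes the per-request overhead in the common online-serving case, where most requests fit comfortably into the remaining budget and the simulation never runs.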
python/sglang/srt/managers/tp_worker.py  (view file @ 05bea688)
@@ -231,6 +231,7 @@ class ModelTpServer:
                 recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
             ):
                 self.handle_generate_request(recv_req)
+                self.do_not_get_new_batch = False
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
@@ -254,12 +255,10 @@ class ModelTpServer:
 
     @torch.inference_mode()
     def forward_step(self):
-        if self.current_inflight_req is not None:
-            self.do_not_get_new_batch = False
-
-        new_batch = (
-            self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
-        )
-
+        if self.do_not_get_new_batch and self.current_inflight_req is None:
+            new_batch = None
+        else:
+            new_batch = self.get_new_prefill_batch()
+        self.do_not_get_new_batch = False
 
         if new_batch is not None:
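The reworked forward_step reads as: skip building a prefill batch only when the previous step asked to skip and there is no inflight (chunked) request, then clear the flag every iteration. A stripped-down sketch of that shape, assuming a duck-typed server object with just the attributes and methods named in the diff (run_prefill is a hypothetical placeholder):

def forward_step(server) -> None:
    # Skip building a prefill batch only if the previous step asked to skip
    # AND no inflight (chunked) request is waiting to be continued.
    if server.do_not_get_new_batch and server.current_inflight_req is None:
        new_batch = None
    else:
        new_batch = server.get_new_prefill_batch()

    # Clear the flag every step, so a single skip decision cannot keep
    # delaying requests that arrived in the meantime (recv_requests also
    # clears it whenever a new generate request comes in).
    server.do_not_get_new_batch = False

    if new_batch is not None:
        server.run_prefill(new_batch)  # hypothetical; stands in for the real prefill path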