Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
05bea688
"..._static/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "ebbd5f643d3006c601183e6f5a111611663754c5"
Unverified
Commit
05bea688
authored
Sep 07, 2024
by
Liangsheng Yin
Committed by
GitHub
Sep 07, 2024
Browse files
Fix some online scheduling delay (#1345)
parent
ab4a83b2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
36 deletions
+52
-36
python/sglang/srt/managers/policy_scheduler.py
python/sglang/srt/managers/policy_scheduler.py
+47
-30
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+5
-6
No files found.
python/sglang/srt/managers/policy_scheduler.py
View file @
05bea688
...
@@ -119,6 +119,7 @@ class PrefillAdder:
...
@@ -119,6 +119,7 @@ class PrefillAdder:
self
.
running_batch
=
running_batch
self
.
running_batch
=
running_batch
self
.
new_token_ratio
=
new_token_ratio
self
.
new_token_ratio
=
new_token_ratio
self
.
rem_total_tokens
=
rem_total_tokens
-
mixed_with_decode_tokens
self
.
rem_total_tokens
=
rem_total_tokens
-
mixed_with_decode_tokens
self
.
rem_total_tokens_
=
self
.
rem_total_tokens
self
.
total_tokens
=
rem_total_tokens
self
.
total_tokens
=
rem_total_tokens
self
.
rem_input_tokens
=
rem_input_tokens
-
mixed_with_decode_tokens
self
.
rem_input_tokens
=
rem_input_tokens
-
mixed_with_decode_tokens
self
.
rem_chunk_tokens
=
rem_chunk_tokens
self
.
rem_chunk_tokens
=
rem_chunk_tokens
...
@@ -153,11 +154,18 @@ class PrefillAdder:
...
@@ -153,11 +154,18 @@ class PrefillAdder:
for
r
in
running_batch
.
reqs
for
r
in
running_batch
.
reqs
]
]
)
)
self
.
rem_total_tokens_
-=
sum
(
[
r
.
sampling_params
.
max_new_tokens
-
len
(
r
.
output_ids
)
for
r
in
running_batch
.
reqs
]
)
def
_prefill_one_req
(
def
_prefill_one_req
(
self
,
prefix_len
:
int
,
extend_input_len
:
int
,
max_new_tokens
:
int
self
,
prefix_len
:
int
,
extend_input_len
:
int
,
max_new_tokens
:
int
):
):
self
.
rem_total_tokens
-=
extend_input_len
+
max_new_tokens
self
.
rem_total_tokens
-=
extend_input_len
+
max_new_tokens
self
.
rem_total_tokens_
-=
extend_input_len
+
max_new_tokens
self
.
rem_input_tokens
-=
extend_input_len
self
.
rem_input_tokens
-=
extend_input_len
if
self
.
rem_chunk_tokens
is
not
None
:
if
self
.
rem_chunk_tokens
is
not
None
:
self
.
rem_chunk_tokens
-=
extend_input_len
self
.
rem_chunk_tokens
-=
extend_input_len
...
@@ -231,43 +239,52 @@ class PrefillAdder:
...
@@ -231,43 +239,52 @@ class PrefillAdder:
return
None
return
None
if
self
.
req_states
is
None
:
# Quick Check
self
.
req_states
=
[]
can_run
=
False
if
self
.
running_batch
is
not
None
:
if
(
for
r
in
self
.
running_batch
.
reqs
:
req
.
extend_input_len
+
req
.
sampling_params
.
max_new_tokens
<=
self
.
rem_total_tokens
):
can_run
=
True
if
not
can_run
:
if
self
.
req_states
is
None
:
self
.
req_states
=
[]
if
self
.
running_batch
is
not
None
:
for
r
in
self
.
running_batch
.
reqs
:
state
=
get_req_state
(
r
)
if
state
is
not
None
:
self
.
req_states
.
append
(
state
)
for
r
in
self
.
can_run_list
:
state
=
get_req_state
(
r
)
state
=
get_req_state
(
r
)
if
state
is
not
None
:
if
state
is
not
None
:
self
.
req_states
.
append
(
state
)
self
.
req_states
.
append
(
state
)
for
r
in
self
.
can_run_list
:
state
=
get_req_state
(
req
)
state
=
get_req_state
(
r
)
if
state
is
not
None
:
if
state
is
not
None
:
self
.
req_states
.
append
(
state
)
self
.
req_states
.
append
(
state
)
state
=
get_req_state
(
req
)
if
state
is
not
None
:
self
.
req_states
.
append
(
state
)
self
.
req_states
.
sort
(
key
=
lambda
x
:
x
[
0
])
self
.
req_states
.
sort
(
key
=
lambda
x
:
x
[
0
])
else
:
else
:
state
=
get_req_state
(
req
)
state
=
get_req_state
(
req
)
if
state
is
not
None
:
if
state
is
not
None
:
for
i
,
(
tokens_left
,
tokens_occupied
)
in
enumerate
(
self
.
req_states
):
for
i
,
(
tokens_left
,
tokens_occupied
)
in
enumerate
(
self
.
req_states
):
if
tokens_left
>=
state
[
0
]:
if
tokens_left
>=
state
[
0
]:
self
.
req_states
.
insert
(
i
,
state
)
self
.
req_states
.
insert
(
i
,
state
)
break
break
else
:
else
:
self
.
req_states
.
append
(
state
)
self
.
req_states
.
append
(
state
)
tokens_freed
=
0
tokens_freed
=
0
for
i
,
(
tokens_left
,
tokens_occupied
)
in
enumerate
(
self
.
req_states
):
for
i
,
(
tokens_left
,
tokens_occupied
)
in
enumerate
(
self
.
req_states
):
decode_steps
=
(
decode_steps
=
(
self
.
req_states
[
i
+
1
][
0
]
self
.
req_states
[
i
+
1
][
0
]
if
i
+
1
<
len
(
self
.
req_states
)
if
i
+
1
<
len
(
self
.
req_states
)
else
tokens_left
else
tokens_left
)
)
bs
=
len
(
self
.
req_states
)
-
i
bs
=
len
(
self
.
req_states
)
-
i
if
self
.
total_tokens
+
tokens_freed
-
decode_steps
*
bs
<=
0
:
if
self
.
total_tokens
+
tokens_freed
-
decode_steps
*
bs
<=
0
:
return
False
return
False
tokens_freed
+=
tokens_occupied
tokens_freed
+=
tokens_occupied
if
req
.
extend_input_len
<=
self
.
rem_chunk_tokens
:
if
req
.
extend_input_len
<=
self
.
rem_chunk_tokens
:
self
.
can_run_list
.
append
(
req
)
self
.
can_run_list
.
append
(
req
)
...
...
python/sglang/srt/managers/tp_worker.py
View file @
05bea688
...
@@ -231,6 +231,7 @@ class ModelTpServer:
...
@@ -231,6 +231,7 @@ class ModelTpServer:
recv_req
,
(
TokenizedGenerateReqInput
,
TokenizedEmbeddingReqInput
)
recv_req
,
(
TokenizedGenerateReqInput
,
TokenizedEmbeddingReqInput
)
):
):
self
.
handle_generate_request
(
recv_req
)
self
.
handle_generate_request
(
recv_req
)
self
.
do_not_get_new_batch
=
False
elif
isinstance
(
recv_req
,
FlushCacheReq
):
elif
isinstance
(
recv_req
,
FlushCacheReq
):
self
.
flush_cache
()
self
.
flush_cache
()
elif
isinstance
(
recv_req
,
AbortReq
):
elif
isinstance
(
recv_req
,
AbortReq
):
...
@@ -254,12 +255,10 @@ class ModelTpServer:
...
@@ -254,12 +255,10 @@ class ModelTpServer:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
forward_step
(
self
):
def
forward_step
(
self
):
if
self
.
current_inflight_req
is
not
None
:
if
self
.
do_not_get_new_batch
and
self
.
current_inflight_req
is
None
:
self
.
do_not_get_new_batch
=
False
new_batch
=
None
else
:
new_batch
=
(
new_batch
=
self
.
get_new_prefill_batch
()
self
.
get_new_prefill_batch
()
if
not
self
.
do_not_get_new_batch
else
None
)
self
.
do_not_get_new_batch
=
False
self
.
do_not_get_new_batch
=
False
if
new_batch
is
not
None
:
if
new_batch
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment