sglang · Commits · ab4a83b2

Unverified commit ab4a83b2, authored Sep 05, 2024 by Liangsheng Yin, committed by GitHub on Sep 05, 2024.

Commit message:
Optimize schedule (#1339)

Parent: 62f15eea

Showing 2 changed files with 123 additions and 8 deletions:

- python/sglang/srt/managers/policy_scheduler.py (+105, -5)
- python/sglang/srt/managers/tp_worker.py (+18, -3)

python/sglang/srt/managers/policy_scheduler.py (view file @ ab4a83b2)

```diff
@@ -108,18 +108,24 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
+        running_batch: ScheduleBatch,
+        new_token_ratio: float,
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
+        self.running_batch = running_batch
+        self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
+        self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
```

```diff
@@ -136,16 +142,14 @@ class PrefillAdder:
             )
         )
 
-    def remove_running_tokens(
-        self, running_batch: ScheduleBatch, new_token_ratio: float
-    ):
+    def remove_running_tokens(self, running_batch: ScheduleBatch):
         self.rem_total_tokens -= sum(
             [
                 min(
                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
                     CLIP_MAX_NEW_TOKENS,
                 )
-                * new_token_ratio
+                * self.new_token_ratio
                 for r in running_batch.reqs
             ]
         )
```
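
With `new_token_ratio` now stored on the adder, `remove_running_tokens` only needs the running batch: it reserves, out of the remaining KV-token budget, an estimate of the tokens the already-running requests will still generate. Below is a minimal, self-contained sketch of that reservation; `ToyReq`, `reserve_running_tokens`, and the `CLIP_MAX_NEW_TOKENS` value here are stand-ins for illustration, not the sglang API.

```python
from dataclasses import dataclass, field
from typing import List

CLIP_MAX_NEW_TOKENS = 256  # placeholder clip value for this sketch


@dataclass
class ToyReq:
    max_new_tokens: int
    output_ids: List[int] = field(default_factory=list)


def reserve_running_tokens(
    rem_total_tokens: float, reqs: List[ToyReq], new_token_ratio: float
) -> float:
    """Subtract an estimate of the tokens the running requests still need."""
    reserved = sum(
        min(r.max_new_tokens - len(r.output_ids), CLIP_MAX_NEW_TOKENS) * new_token_ratio
        for r in reqs
    )
    return rem_total_tokens - reserved


# Two running requests may still emit up to 100 and 20 tokens; with a ratio of
# 0.5 that reserves 60 tokens out of a 1000-token budget.
print(reserve_running_tokens(1000, [ToyReq(128, [0] * 28), ToyReq(50, [0] * 30)], 0.5))  # 940.0
```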

```diff
@@ -161,7 +165,29 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
+    def add_inflight_req_ignore_eos(self, req: Req):
+        truncated = req.extend_input_len > self.rem_chunk_tokens
+        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
+        self.can_run_list.append(req)
+
+        self._prefill_one_req(
+            0,
+            req.extend_input_len,
+            (
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                if not truncated
+                else 0
+            ),
+        )
+
+        # Return if chunked prefill not finished
+        return req if truncated else None
+
     def add_inflight_req(self, req: Req):
+        if req.sampling_params.ignore_eos:
+            return self.add_inflight_req_ignore_eos(req)
+
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
```
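
The new `add_inflight_req_ignore_eos` path keeps the existing chunked-prefill arithmetic: the prefill is truncated to the remaining chunk budget, and decode tokens are only reserved when the whole remaining prompt fits in this chunk. A rough standalone sketch of that decision, using hypothetical names rather than the real `Req`/`PrefillAdder` objects:

```python
def plan_inflight_chunk(extend_input_len: int, rem_chunk_tokens: int):
    """Return (tokens_prefilled_now, prefill_finished) for one chunked-prefill step.

    Decode tokens are only budgeted when the prefill finishes in this chunk,
    since generation cannot start for a truncated (still in-flight) request.
    """
    truncated = extend_input_len > rem_chunk_tokens
    tokens_now = min(extend_input_len, rem_chunk_tokens)
    return tokens_now, not truncated


print(plan_inflight_chunk(900, 512))  # (512, False): the request stays in flight
print(plan_inflight_chunk(388, 512))  # (388, True): prefill completes, decode can be reserved
```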

```diff
@@ -190,7 +216,81 @@ class PrefillAdder:
             delta = self.tree_cache.dec_lock_ref(last_node)
             self.rem_total_tokens += delta
 
+    def add_one_req_ignore_eos(self, req: Req):
+        def get_req_state(r):
+            new_token_ratio = (
+                1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
+            )
+            tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
+                r.output_ids
+            )
+            tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
+
+            if tokens_left > 0:
+                return (tokens_left, tokens_occupied)
+
+            return None
+
+        if self.req_states is None:
+            self.req_states = []
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    state = get_req_state(r)
+                    if state is not None:
+                        self.req_states.append(state)
+            for r in self.can_run_list:
+                state = get_req_state(r)
+                if state is not None:
+                    self.req_states.append(state)
+            state = get_req_state(req)
+            if state is not None:
+                self.req_states.append(state)
+
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            state = get_req_state(req)
+            if state is not None:
+                for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                    if tokens_left >= state[0]:
+                        self.req_states.insert(i, state)
+                        break
+                else:
+                    self.req_states.append(state)
+
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied
+
+        if req.extend_input_len <= self.rem_chunk_tokens:
+            self.can_run_list.append(req)
+            self._prefill_one_req(
+                0,
+                req.extend_input_len,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+            )
+        else:
+            # Chunked prefill
+            trunc_len = self.rem_chunk_tokens
+            req.extend_input_len = trunc_len
+            req.fill_ids = req.fill_ids[:trunc_len]
+            self.can_run_list.append(req)
+            self.new_inflight_req = req
+            self._prefill_one_req(0, trunc_len, 0)
+
+        return True
+
     def add_one_req(self, req: Req):
+        if req.sampling_params.ignore_eos and self.tree_cache.disable:
+            return self.add_one_req_ignore_eos(req)
+
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
         )
```
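
The core of the new `add_one_req_ignore_eos` path is a worst-case memory check. Because `ignore_eos` requests run for exactly `max_new_tokens` steps, each request can be summarized as `(tokens_left, tokens_occupied)`; sorting by `tokens_left` and sweeping in finishing order verifies that the KV budget, plus whatever earlier-finishing requests free, always covers the decode tokens the remaining requests need. A self-contained sketch of that sweep (toy function name and values, not the sglang API):

```python
from typing import List, Tuple


def fits_in_budget(total_tokens: int, req_states: List[Tuple[int, int]]) -> bool:
    """req_states holds (tokens_left, tokens_occupied) per request.

    Mirrors the loop above: for each request i (in finishing order), take the
    next request's remaining length as the number of decode steps the `bs`
    still-running requests must survive, and require the budget plus tokens
    already freed to cover `decode_steps * bs`.
    """
    req_states = sorted(req_states)
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        decode_steps = req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        bs = len(req_states) - i
        if total_tokens + tokens_freed - decode_steps * bs <= 0:
            return False
        tokens_freed += tokens_occupied
    return True


# Three requests still needing 10, 40, and 100 tokens, each occupying 200 tokens:
states = [(10, 200), (40, 200), (100, 200)]
print(fits_in_budget(600, states))  # True
print(fits_in_budget(100, states))  # False: 3 requests x 40 steps already exceed 100
```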

```diff
@@ -233,4 +333,4 @@ class PrefillAdder:
         self.tree_cache.inc_lock_ref(req.last_node)
         self._prefill_one_req(prefix_len, trunc_len, 0)
 
-        return True
+        return True and not self.no_remaining_tokens()
```

python/sglang/srt/managers/tp_worker.py (view file @ ab4a83b2)

```diff
@@ -221,6 +221,7 @@ class ModelTpServer:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.do_not_get_new_batch = False
 
     def exposed_step(self, recv_reqs: List):
         try:
```

```diff
@@ -253,7 +254,13 @@ class ModelTpServer:
 
     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.get_new_prefill_batch()
+        if self.current_inflight_req is not None:
+            self.do_not_get_new_batch = False
+
+        new_batch = (
+            self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
+        )
+        self.do_not_get_new_batch = False
 
         if new_batch is not None:
             # Run a new prefill batch
```

```diff
@@ -409,6 +416,8 @@ class ModelTpServer:
         adder = PrefillAdder(
             self.tree_cache,
+            self.running_batch,
+            self.new_token_ratio,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
```

```diff
@@ -416,7 +425,7 @@ class ModelTpServer:
         )
 
         if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+            adder.remove_running_tokens(self.running_batch)
 
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
```

```diff
@@ -428,11 +437,12 @@ class ModelTpServer:
         )
 
         for req in self.waiting_queue:
+            if adder.no_remaining_tokens():
+                break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
-                or adder.no_remaining_tokens()
                 or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
                 break
```
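
Together with the `return True and not self.no_remaining_tokens()` change at the end of `policy_scheduler.py`, the loop above now checks the budget before touching a waiting request, and relies on `add_one_req`'s return value to stop once the budget is consumed. A toy illustration of that admission contract (hypothetical `ToyAdder`, not the real `PrefillAdder`):

```python
class ToyAdder:
    """Toy admission budget: add_one_req returns True only if the request was
    admitted AND tokens remain, so the caller needs just one `not res` check."""

    def __init__(self, budget: int):
        self.budget = budget
        self.can_run_list = []

    def no_remaining_tokens(self) -> bool:
        return self.budget <= 0

    def add_one_req(self, need: int) -> bool:
        if need > self.budget:
            return False  # not admitted
        self.budget -= need
        self.can_run_list.append(need)
        return True and not self.no_remaining_tokens()  # admitted; maybe the last one


adder = ToyAdder(budget=10)
for need in [4, 6, 3]:
    if adder.no_remaining_tokens():
        break  # budget already gone: skip the request entirely
    if not adder.add_one_req(need):
        break
print(adder.can_run_list)  # [4, 6] -- the third request is never examined
```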

```diff
@@ -700,6 +710,7 @@ class ModelTpServer:
         next_token_ids = next_token_ids.tolist()
 
         # Check finish condition
+        has_finished = False
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
```

```diff
@@ -712,6 +723,7 @@ class ModelTpServer:
 
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
+                has_finished = True
 
             if req.return_logprob:
                 req.output_token_logprobs.append(
```

```diff
@@ -720,6 +732,9 @@ class ModelTpServer:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+        if not has_finished:
+            self.do_not_get_new_batch = True
+
         self.handle_finished_requests(batch)
 
     def handle_finished_requests(self, batch: ScheduleBatch):
```
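
Taken together, the `tp_worker.py` hunks add a small scheduling shortcut: when a decode step finishes no request, `do_not_get_new_batch` is set so that the next `forward_step` does not try to build a prefill batch (nothing was freed, so the attempt would mostly be wasted work); the flag is cleared when an in-flight chunked prefill must continue and after it has been consumed. A compressed, self-contained sketch of that control flow, with toy stand-ins instead of the real `ModelTpServer`:

```python
class ToyServer:
    """Toy control flow mirroring the do_not_get_new_batch flag."""

    def __init__(self, finishes_per_step):
        self.do_not_get_new_batch = False
        self.current_inflight_req = None
        self.finishes_per_step = finishes_per_step  # did any request finish at step i?
        self.prefill_attempts = 0

    def get_new_prefill_batch(self):
        self.prefill_attempts += 1  # count scheduling attempts only
        return None  # pretend nothing is waiting

    def forward_step(self, step):
        if self.current_inflight_req is not None:
            self.do_not_get_new_batch = False
        new_batch = (  # unused in this toy; kept to mirror the diff's shape
            self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
        )
        self.do_not_get_new_batch = False
        # ... run prefill/decode here ...
        has_finished = self.finishes_per_step[step]
        if not has_finished:
            self.do_not_get_new_batch = True


server = ToyServer(finishes_per_step=[False, False, True, False])
for step in range(4):
    server.forward_step(step)
print(server.prefill_attempts)  # 2 -- steps 1 and 2 skip the prefill attempt
```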