sglang / Commits / 36078fb2

Unverified commit 36078fb2, authored Sep 17, 2024 by Liangsheng Yin, committed by GitHub on Sep 17, 2024.
fix schedule bug (#1450)
Parent: b3710d2c
Showing 2 changed files with 59 additions and 99 deletions:

python/sglang/srt/managers/policy_scheduler.py  (+48, -93)
python/sglang/srt/managers/tp_worker.py         (+11, -6)
python/sglang/srt/managers/policy_scheduler.py
@@ -119,19 +119,32 @@ class PrefillAdder:
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
-        self.rem_total_tokens_ = self.rem_total_tokens
-        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
+        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
+
         self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
 
+        if running_batch is not None:
+            # Pre-remove the tokens which will be occupied by the running requests
+            self.rem_total_tokens -= sum(
+                [
+                    min(
+                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                        CLIP_MAX_NEW_TOKENS,
+                    )
+                    * self.new_token_ratio
+                    for r in running_batch.reqs
+                ]
+            )
+
     def no_remaining_tokens(self):
         return (
             self.rem_total_tokens <= 0
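Note: the reservation formerly done by the separate remove_running_tokens() method now happens in the constructor itself. A self-contained sketch of that computation; CLIP_MAX_NEW_TOKENS is given an assumed value here, and plain tuples stand in for sglang's Req objects:

    # Hedged sketch, not sglang's API: each running request holds back up to
    # CLIP_MAX_NEW_TOKENS of its remaining decode budget, scaled by new_token_ratio.
    CLIP_MAX_NEW_TOKENS = 512  # assumed value for illustration

    def reserve_for_running(rem_total_tokens, running_reqs, new_token_ratio):
        # running_reqs: iterable of (max_new_tokens, num_output_ids) pairs
        reserved = sum(
            min(max_new - out_len, CLIP_MAX_NEW_TOKENS) * new_token_ratio
            for max_new, out_len in running_reqs
        )
        return rem_total_tokens - reserved

    # Two running requests, each allowed 128 more tokens, at ratio 0.5:
    print(reserve_for_running(10_000, [(256, 128), (256, 128)], 0.5))  # 9872.0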
@@ -141,31 +154,14 @@ class PrefillAdder:
             if self.rem_chunk_tokens is not None
             else False
             )
+            or self.cur_rem_tokens <= 0
         )
 
-    def remove_running_tokens(self, running_batch: ScheduleBatch):
-        self.rem_total_tokens -= sum(
-            [
-                min(
-                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
-                )
-                * self.new_token_ratio
-                for r in running_batch.reqs
-            ]
-        )
-        self.rem_total_tokens_ -= sum(
-            [
-                r.sampling_params.max_new_tokens - len(r.output_ids)
-                for r in running_batch.reqs
-            ]
-        )
-
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
-        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
+        self.cur_rem_tokens -= extend_input_len
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
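After this hunk, the adder debits two budgets on every prefill: rem_total_tokens (input plus the request's decode allowance) and the new cur_rem_tokens (input only), which the feasibility loop later in this file draws on. A simplified sketch of that double bookkeeping, under assumed names rather than the real class:

    class BudgetSketch:
        def __init__(self, rem_total_tokens, rem_input_tokens, rem_chunk_tokens=None):
            self.rem_total_tokens = rem_total_tokens  # inputs + projected decode tokens
            self.cur_rem_tokens = rem_total_tokens    # inputs only
            self.rem_input_tokens = rem_input_tokens
            self.rem_chunk_tokens = rem_chunk_tokens

        def prefill_one_req(self, extend_input_len, max_new_tokens):
            self.rem_total_tokens -= extend_input_len + max_new_tokens
            self.cur_rem_tokens -= extend_input_len  # decode is not pre-charged here
            self.rem_input_tokens -= extend_input_len
            if self.rem_chunk_tokens is not None:
                self.rem_chunk_tokens -= extend_input_len

    b = BudgetSketch(10_000, 4_096)
    b.prefill_one_req(extend_input_len=1_000, max_new_tokens=256)
    print(b.rem_total_tokens, b.cur_rem_tokens)  # 8744 9000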
@@ -173,29 +169,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def add_inflight_req_ignore_eos(self, req: Req):
-        truncated = req.extend_input_len > self.rem_chunk_tokens
-        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
-        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
-        self.can_run_list.append(req)
-
-        self._prefill_one_req(
-            0,
-            req.extend_input_len,
-            (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
-                if not truncated
-                else 0
-            ),
-        )
-
-        # Return if chunked prefill not finished
-        return req if truncated else None
-
     def add_inflight_req(self, req: Req):
-        if req.sampling_params.ignore_eos:
-            return self.add_inflight_req_ignore_eos(req)
-
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
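The merged add_inflight_req keeps the truncation logic that caps a chunked prefill at the remaining chunk budget; a truncated request is handed back so its prefill resumes in a later round. A toy version of just that step, with simplified names:

    def truncate_for_chunk(extend_input_len, rem_chunk_tokens):
        # Mirrors the truncated/min() pair above.
        truncated = extend_input_len > rem_chunk_tokens
        return min(extend_input_len, rem_chunk_tokens), truncated

    print(truncate_for_chunk(3000, 2048))  # (2048, True): continue next round
    print(truncate_for_chunk(1000, 2048))  # (1000, False): fits in this chunk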
@@ -225,7 +199,7 @@ class PrefillAdder:
         self.rem_total_tokens += delta
 
     def add_one_req_ignore_eos(self, req: Req):
-        def get_req_state(r):
+        def add_req_state(r, insert_sort=False):
             new_token_ratio = (
                 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
             )
@@ -235,56 +209,37 @@ class PrefillAdder:
             tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
 
             if tokens_left > 0:
-                return (tokens_left, tokens_occupied)
-            return None
-
-        # Quick Check
-        can_run = False
-        if (
-            req.extend_input_len + req.sampling_params.max_new_tokens
-            <= self.rem_total_tokens
-        ):
-            can_run = True
-
-        if not can_run:
-            if self.req_states is None:
-                self.req_states = []
-                if self.running_batch is not None:
-                    for r in self.running_batch.reqs:
-                        state = get_req_state(r)
-                        if state is not None:
-                            self.req_states.append(state)
-                for r in self.can_run_list:
-                    state = get_req_state(r)
-                    if state is not None:
-                        self.req_states.append(state)
-                state = get_req_state(req)
-                if state is not None:
-                    self.req_states.append(state)
-                self.req_states.sort(key=lambda x: x[0])
-            else:
-                state = get_req_state(req)
-                if state is not None:
-                    for i, (tokens_left, tokens_occupied) in enumerate(
-                        self.req_states
-                    ):
-                        if tokens_left >= state[0]:
-                            self.req_states.insert(i, state)
-                            break
-                    else:
-                        self.req_states.append(state)
-
-            tokens_freed = 0
-            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                decode_steps = (
-                    self.req_states[i + 1][0]
-                    if i + 1 < len(self.req_states)
-                    else tokens_left
-                )
-                bs = len(self.req_states) - i
-                if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
-                    return False
-                tokens_freed += tokens_occupied
+                if not insert_sort:
+                    self.req_states.append((tokens_left, tokens_occupied))
+                else:
+                    for i in range(len(self.req_states)):
+                        if tokens_left <= self.req_states[i][0]:
+                            break
+                    self.req_states.insert(i, (tokens_left, tokens_occupied))
+
+        if self.req_states is None:
+            self.req_states = []
+            add_req_state(req)
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    add_req_state(r)
+            for r in self.can_run_list:
+                add_req_state(r)
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            add_req_state(req, insert_sort=True)
+
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if self.cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied
 
         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
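The core of the fix is visible in this last hunk: instead of the old "Quick Check" fast path against rem_total_tokens (and the stale total_tokens counter), every admission now simulates decoding forward over req_states sorted by tokens_left, checking cur_rem_tokens against the projected memory demand. A self-contained sketch of that simulation; names are assumed and tuples stand in for request state:

    def decode_is_feasible(cur_rem_tokens, req_states):
        # req_states: (tokens_left, tokens_occupied) pairs sorted by tokens_left.
        # Walk forward in time; as each request finishes, its memory is freed.
        tokens_freed = 0
        for i, (tokens_left, tokens_occupied) in enumerate(req_states):
            # Steps until the next request completes, charged to all bs
            # requests still running at this point.
            decode_steps = (
                req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
            )
            bs = len(req_states) - i
            if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
                return False
            tokens_freed += tokens_occupied
        return True

    # Three requests needing 10/20/30 more tokens, each occupying 100 tokens:
    print(decode_is_feasible(200, [(10, 100), (20, 100), (30, 100)]))  # True
    print(decode_is_feasible(25, [(10, 100), (20, 100), (30, 100)]))   # False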
python/sglang/srt/managers/tp_worker.py
@@ -445,9 +445,6 @@ class ModelTpServer:
             num_mixed_running,
         )
 
-        if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch)
-
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
             self.current_inflight_req.init_next_round_input(
@@ -465,9 +462,6 @@ class ModelTpServer:
         )
 
         for req in self.waiting_queue:
-            if adder.no_remaining_tokens():
-                break
-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             if (
                 self.lora_paths is not None
                 and len(
@@ -478,6 +472,10 @@ class ModelTpServer:
                 > self.max_loras_per_batch
             ):
                 break
+
+            if adder.no_remaining_tokens():
+                break
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
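This hunk reorders the admission loop: the LoRA capacity check now fires before the token-budget break, and init_next_round_input runs only after both checks pass, so a request is never prepared for a batch it cannot join. A toy model of the new ordering, with assumed names standing in for adder.no_remaining_tokens() and the LoRA bookkeeping on Req objects:

    def admit(waiting, lora_slots, budget):
        # waiting: (lora_path_or_None, token_cost) pairs, in queue order.
        admitted, used_loras = [], set()
        for path, cost in waiting:
            if path is not None and len(used_loras | {path}) > lora_slots:
                break  # LoRA capacity check comes first now
            if budget <= 0:  # stands in for adder.no_remaining_tokens()
                break
            budget -= cost
            admitted.append((path, cost))
            if path is not None:
                used_loras.add(path)
        return admitted

    print(admit([("a", 10), ("b", 10), ("c", 10)], lora_slots=2, budget=100))
    # [('a', 10), ('b', 10)]: the third request breaks on the LoRA slot limit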
@@ -507,6 +505,11 @@ class ModelTpServer:
         else:
             tree_cache_hit_rate = 0.0
 
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+        )
+
         if num_mixed_running > 0:
             logger.info(
                 f"Prefill batch"
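The new num_used value feeds the "token usage" field added to both log lines below: the fraction of the KV pool that is neither free nor evictable from the prefix cache. A minimal illustration with made-up numbers:

    max_total_num_tokens = 100_000
    available_size = 30_000  # stands in for token_to_kv_pool.available_size()
    evictable_size = 20_000  # stands in for tree_cache.evictable_size()

    num_used = max_total_num_tokens - (available_size + evictable_size)
    print(f"token usage: {num_used / max_total_num_tokens:.2f}")  # token usage: 0.50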
@@ -515,6 +518,7 @@ class ModelTpServer:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
         else:
@@ -524,6 +528,7 @@ class ModelTpServer:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )