change / sglang · Commits · 36078fb2

Commit 36078fb2 (unverified), authored Sep 17, 2024 by Liangsheng Yin, committed by GitHub on Sep 17, 2024

fix schedule bug (#1450)

parent b3710d2c
Showing 2 changed files with 59 additions and 99 deletions (+59, -99):

python/sglang/srt/managers/policy_scheduler.py  +48 -93
python/sglang/srt/managers/tp_worker.py         +11 -6
python/sglang/srt/managers/policy_scheduler.py
```diff
@@ -119,19 +119,32 @@ class PrefillAdder:
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
-        self.rem_total_tokens_ = self.rem_total_tokens
-        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
+        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
+
         self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
 
+        if running_batch is not None:
+            # Pre-remove the tokens which will be occupied by the running requests
+            self.rem_total_tokens -= sum(
+                [
+                    min(
+                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                        CLIP_MAX_NEW_TOKENS,
+                    )
+                    * self.new_token_ratio
+                    for r in running_batch.reqs
+                ]
+            )
+
     def no_remaining_tokens(self):
         return (
             self.rem_total_tokens <= 0
```
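The `__init__` hunk above folds the old `remove_running_tokens` pre-reservation into construction and introduces `cur_rem_tokens`, a budget that is charged only for prefill input. Below is a minimal, self-contained sketch of the reservation arithmetic; `FakeReq`, `reserve_running_tokens`, and the `CLIP_MAX_NEW_TOKENS` value are illustrative stand-ins, not sglang APIs:

```python
from dataclasses import dataclass, field
from typing import List

CLIP_MAX_NEW_TOKENS = 256  # illustrative value; the real constant lives in sglang


@dataclass
class FakeSamplingParams:
    max_new_tokens: int


@dataclass
class FakeReq:
    sampling_params: FakeSamplingParams
    output_ids: List[int] = field(default_factory=list)


def reserve_running_tokens(
    rem_total_tokens: int, reqs: List[FakeReq], new_token_ratio: float
) -> float:
    # Each running request reserves its remaining decode budget, clipped to
    # CLIP_MAX_NEW_TOKENS and discounted by new_token_ratio, mirroring the
    # pre-removal this commit moves into PrefillAdder.__init__.
    return rem_total_tokens - sum(
        min(r.sampling_params.max_new_tokens - len(r.output_ids), CLIP_MAX_NEW_TOKENS)
        * new_token_ratio
        for r in reqs
    )


running = [
    FakeReq(FakeSamplingParams(512), output_ids=[0] * 100),  # 412 left, clipped to 256
    FakeReq(FakeSamplingParams(64), output_ids=[0] * 60),    # 4 left
]
print(reserve_running_tokens(10_000, running, new_token_ratio=0.5))  # 9870.0
```

Discounting by `new_token_ratio` presumably reflects that most requests hit EOS well before `max_new_tokens`, so reserving the full clipped budget for every running request would be overly pessimistic.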
```diff
@@ -141,31 +154,14 @@ class PrefillAdder:
                 if self.rem_chunk_tokens is not None
                 else False
             )
-        )
-
-    def remove_running_tokens(self, running_batch: ScheduleBatch):
-        self.rem_total_tokens -= sum(
-            [
-                min(
-                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
-                )
-                * self.new_token_ratio
-                for r in running_batch.reqs
-            ]
-        )
-        self.rem_total_tokens_ -= sum(
-            [
-                r.sampling_params.max_new_tokens - len(r.output_ids)
-                for r in running_batch.reqs
-            ]
-        )
+        ) or self.cur_rem_tokens <= 0
 
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
-        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
+        self.cur_rem_tokens -= extend_input_len
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
```
```diff
@@ -173,29 +169,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def add_inflight_req_ignore_eos(self, req: Req):
-        truncated = req.extend_input_len > self.rem_chunk_tokens
-        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
-        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
-        self.can_run_list.append(req)
-
-        self._prefill_one_req(
-            0,
-            req.extend_input_len,
-            (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
-                if not truncated
-                else 0
-            ),
-        )
-
-        # Return if chunked prefill not finished
-        return req if truncated else None
-
     def add_inflight_req(self, req: Req):
-        if req.sampling_params.ignore_eos:
-            return self.add_inflight_req_ignore_eos(req)
-
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
```
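The deleted `add_inflight_req_ignore_eos` duplicated the truncation logic that the surviving `add_inflight_req` already performs, and with the `ignore_eos` dispatch removed, all inflight (chunked-prefill) requests now take the same path. A toy illustration of that shared truncation step, with a hypothetical function name and arguments:

```python
def truncate_for_chunk(extend_input_len: int, prefix_len: int,
                       fill_ids: list, rem_chunk_tokens: int):
    # Clip the prefill extension to the remaining chunk budget, keeping only
    # the token ids that fit; `truncated` means the request needs another round.
    truncated = extend_input_len > rem_chunk_tokens
    extend_input_len = min(extend_input_len, rem_chunk_tokens)
    fill_ids = fill_ids[: prefix_len + extend_input_len]
    return extend_input_len, fill_ids, truncated


# A 10-token extension with a 4-token cached prefix and only 6 chunk tokens
# left is split across two prefill rounds.
print(truncate_for_chunk(10, 4, list(range(14)), 6))
# -> (6, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], True)
```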
```diff
@@ -225,7 +199,7 @@ class PrefillAdder:
         self.rem_total_tokens += delta
 
     def add_one_req_ignore_eos(self, req: Req):
-        def get_req_state(r):
+        def add_req_state(r, insert_sort=False):
             new_token_ratio = (
                 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
             )
```
```diff
@@ -235,56 +209,37 @@ class PrefillAdder:
             tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
 
             if tokens_left > 0:
-                return (tokens_left, tokens_occupied)
-
-            return None
-
-        # Quick Check
-        can_run = False
-        if (
-            req.extend_input_len + req.sampling_params.max_new_tokens
-            <= self.rem_total_tokens
-        ):
-            can_run = True
-
-        if not can_run:
-            if self.req_states is None:
-                self.req_states = []
-                if self.running_batch is not None:
-                    for r in self.running_batch.reqs:
-                        state = get_req_state(r)
-                        if state is not None:
-                            self.req_states.append(state)
-                for r in self.can_run_list:
-                    state = get_req_state(r)
-                    if state is not None:
-                        self.req_states.append(state)
-                state = get_req_state(req)
-                if state is not None:
-                    self.req_states.append(state)
-                self.req_states.sort(key=lambda x: x[0])
-            else:
-                state = get_req_state(req)
-                if state is not None:
-                    for i, (tokens_left, tokens_occupied) in enumerate(
-                        self.req_states
-                    ):
-                        if tokens_left >= state[0]:
-                            self.req_states.insert(i, state)
-                            break
-                    else:
-                        self.req_states.append(state)
-
-            tokens_freed = 0
-            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                decode_steps = (
-                    self.req_states[i + 1][0]
-                    if i + 1 < len(self.req_states)
-                    else tokens_left
-                )
-                bs = len(self.req_states) - i
-                if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
-                    return False
-                tokens_freed += tokens_occupied
+                if not insert_sort:
+                    self.req_states.append((tokens_left, tokens_occupied))
+                else:
+                    for i in range(len(self.req_states)):
+                        if tokens_left <= self.req_states[i][0]:
+                            break
+                    self.req_states.insert(i, (tokens_left, tokens_occupied))
+
+        if self.req_states is None:
+            self.req_states = []
+            add_req_state(req)
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    add_req_state(r)
+            for r in self.can_run_list:
+                add_req_state(r)
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            add_req_state(req, insert_sort=True)
+
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if self.cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied
 
         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
```
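The rewrite of `add_one_req_ignore_eos` above is the core of the fix as the diff presents it: the optimistic quick check (`extend_input_len + max_new_tokens <= rem_total_tokens`) that previously skipped the decode simulation is gone, every admission now runs the sweep, and the sweep tests `cur_rem_tokens`, which shrinks with each admitted prefill, instead of the static `total_tokens`. A standalone rendition of the sweep with made-up numbers:

```python
def decode_is_feasible(cur_rem_tokens: int, req_states: list) -> bool:
    # Standalone rendition of the sweep: req_states holds
    # (tokens_left, tokens_occupied) pairs sorted ascending by tokens_left.
    # Requests finish in that order; until the next one finishes, all `bs`
    # survivors keep consuming KV tokens, and each finished request credits
    # back the tokens it occupied.
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        decode_steps = (
            req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        )
        bs = len(req_states) - i
        if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
            return False
        tokens_freed += tokens_occupied
    return True


states = [(3, 50), (10, 80)]           # finish in 3 and 10 steps respectively
print(decode_is_feasible(25, states))  # True: 25 - 10 * 2 stays positive
print(decode_is_feasible(15, states))  # False: 10 steps * 2 requests exceeds 15
```

Because `req_states` is sorted ascending by `tokens_left`, requests finish in list order; note the check charges the next request's full `tokens_left` horizon for all `bs` survivors rather than only the increment since the previous finish, which appears to err on the conservative side.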
python/sglang/srt/managers/tp_worker.py
```diff
@@ -445,9 +445,6 @@ class ModelTpServer:
             num_mixed_running,
         )
 
-        if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch)
-
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
             self.current_inflight_req.init_next_round_input(
```
```diff
@@ -465,9 +462,6 @@ class ModelTpServer:
             )
 
         for req in self.waiting_queue:
-            if adder.no_remaining_tokens():
-                break
-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             if (
                 self.lora_paths is not None
                 and len(
```
```diff
@@ -478,6 +472,10 @@ class ModelTpServer:
                 > self.max_loras_per_batch
             ):
                 break
+
+            if adder.no_remaining_tokens():
+                break
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
```
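In `ModelTpServer`, the budget check and `init_next_round_input` move from the top of the admission loop to after the LoRA-capacity check, so an iteration that stops for LoRA reasons no longer prepares inputs it cannot use. A toy sketch of the resulting ordering; none of these helper names are sglang APIs:

```python
def admit(waiting, has_lora_room, has_token_budget, prepare, add_one):
    # Toy version of the loop ordering after this hunk: the LoRA check runs
    # first, the budget check second, and input preparation happens only for
    # requests that pass both.
    admitted = []
    for req in waiting:
        if not has_lora_room(req):
            break
        if not has_token_budget():
            break
        prepare(req)
        if not add_one(req):
            break
        admitted.append(req)
    return admitted


budget = [12]


def add_one(req):
    budget[0] -= 5  # each admitted request consumes 5 budget tokens
    return True


print(admit(
    ["a", "b", "c", "d"],
    has_lora_room=lambda r: True,
    has_token_budget=lambda: budget[0] > 0,
    prepare=lambda r: None,
    add_one=add_one,
))  # -> ['a', 'b', 'c']; the budget is exhausted before 'd'
```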
```diff
@@ -507,6 +505,11 @@ class ModelTpServer:
         else:
             tree_cache_hit_rate = 0.0
 
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+        )
+
         if num_mixed_running > 0:
             logger.info(
                 f"Prefill batch"
```
```diff
@@ -515,6 +518,7 @@ class ModelTpServer:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
         else:
```
```diff
@@ -524,6 +528,7 @@ class ModelTpServer:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
```