zhaoyu6 / sglang
"...model/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "d09464be06abdf0a2f108deae99afd3219065267"
Commit 36078fb2 (unverified), authored Sep 17, 2024 by Liangsheng Yin; committed by GitHub, Sep 17, 2024
Parent: b3710d2c

fix schedule bug (#1450)
Showing 2 changed files with 59 additions and 99 deletions:

python/sglang/srt/managers/policy_scheduler.py (+48, -93)
python/sglang/srt/managers/tp_worker.py (+11, -6)
python/sglang/srt/managers/policy_scheduler.py

@@ -119,19 +119,32 @@ class PrefillAdder:
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
-        self.rem_total_tokens_ = self.rem_total_tokens
+        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
+        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
+
         self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
 
+        if running_batch is not None:
+            # Pre-remove the tokens which will be occupied by the running requests
+            self.rem_total_tokens -= sum(
+                [
+                    min(
+                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                        CLIP_MAX_NEW_TOKENS,
+                    )
+                    * self.new_token_ratio
+                    for r in running_batch.reqs
+                ]
+            )
+
     def no_remaining_tokens(self):
         return (
             self.rem_total_tokens <= 0
@@ -141,31 +154,14 @@ class PrefillAdder:
             if self.rem_chunk_tokens is not None
             else False
             )
+            or self.cur_rem_tokens <= 0
         )
 
-    def remove_running_tokens(self, running_batch: ScheduleBatch):
-        self.rem_total_tokens -= sum(
-            [
-                min(
-                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
-                )
-                * self.new_token_ratio
-                for r in running_batch.reqs
-            ]
-        )
-        self.rem_total_tokens_ -= sum(
-            [
-                r.sampling_params.max_new_tokens - len(r.output_ids)
-                for r in running_batch.reqs
-            ]
-        )
-
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
-        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
+        self.cur_rem_tokens -= extend_input_len
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
@@ -173,29 +169,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def add_inflight_req_ignore_eos(self, req: Req):
-        truncated = req.extend_input_len > self.rem_chunk_tokens
-        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
-        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
-        self.can_run_list.append(req)
-
-        self._prefill_one_req(
-            0,
-            req.extend_input_len,
-            (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
-                if not truncated
-                else 0
-            ),
-        )
-
-        # Return if chunked prefill not finished
-        return req if truncated else None
-
     def add_inflight_req(self, req: Req):
-        if req.sampling_params.ignore_eos:
-            return self.add_inflight_req_ignore_eos(req)
-
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
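
For context on the surviving method: add_inflight_req implements chunked prefill, where a request whose remaining input exceeds rem_chunk_tokens is truncated to the chunk budget and handed back so its prefill resumes next round. Below is a minimal, hypothetical sketch of that truncation; MiniReq is a stand-in for illustration only, not the real Req class.

# Hypothetical sketch of the chunked-prefill truncation in add_inflight_req.
class MiniReq:
    def __init__(self, fill_ids, prefix_len):
        self.fill_ids = fill_ids                       # cached prefix + new input tokens
        self.prefix_indices = list(range(prefix_len))  # stand-in for the cached prefix
        self.extend_input_len = len(fill_ids) - prefix_len

def chunk_prefill(req, rem_chunk_tokens):
    # Truncate the extend part to the per-chunk token budget.
    truncated = req.extend_input_len > rem_chunk_tokens
    req.extend_input_len = min(req.extend_input_len, rem_chunk_tokens)
    req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
    # A truncated request is returned so its prefill continues next round.
    return req if truncated else None

req = MiniReq(fill_ids=list(range(10)), prefix_len=2)  # 8 uncached input tokens
leftover = chunk_prefill(req, rem_chunk_tokens=5)
print(req.extend_input_len, leftover is not None)      # -> 5 True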
@@ -225,7 +199,7 @@ class PrefillAdder:
             self.rem_total_tokens += delta
 
     def add_one_req_ignore_eos(self, req: Req):
-        def get_req_state(r):
+        def add_req_state(r, insert_sort=False):
             new_token_ratio = (
                 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
             )
@@ -235,56 +209,37 @@ class PrefillAdder:
             tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
 
             if tokens_left > 0:
-                return (tokens_left, tokens_occupied)
-
-            return None
-
-        # Quick Check
-        can_run = False
-        if (
-            req.extend_input_len + req.sampling_params.max_new_tokens
-            <= self.rem_total_tokens
-        ):
-            can_run = True
-
-        if not can_run:
-            if self.req_states is None:
-                self.req_states = []
-                if self.running_batch is not None:
-                    for r in self.running_batch.reqs:
-                        state = get_req_state(r)
-                        if state is not None:
-                            self.req_states.append(state)
-                for r in self.can_run_list:
-                    state = get_req_state(r)
-                    if state is not None:
-                        self.req_states.append(state)
-                state = get_req_state(req)
-                if state is not None:
-                    self.req_states.append(state)
-                self.req_states.sort(key=lambda x: x[0])
-            else:
-                state = get_req_state(req)
-                if state is not None:
-                    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                        if tokens_left >= state[0]:
-                            self.req_states.insert(i, state)
-                            break
-                    else:
-                        self.req_states.append(state)
-
-            tokens_freed = 0
-            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                decode_steps = (
-                    self.req_states[i + 1][0]
-                    if i + 1 < len(self.req_states)
-                    else tokens_left
-                )
-                bs = len(self.req_states) - i
-                if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
-                    return False
-                tokens_freed += tokens_occupied
+                if not insert_sort:
+                    self.req_states.append((tokens_left, tokens_occupied))
+                else:
+                    for i in range(len(self.req_states)):
+                        if tokens_left <= self.req_states[i][0]:
+                            break
+                    self.req_states.insert(i, (tokens_left, tokens_occupied))
+
+        if self.req_states is None:
+            self.req_states = []
+            add_req_state(req)
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    add_req_state(r)
+            for r in self.can_run_list:
+                add_req_state(r)
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            add_req_state(req, insert_sort=True)
+
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if self.cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied
 
         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
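The heart of the fix in policy_scheduler.py is the feasibility test that replaces the old "Quick Check": each request is reduced to a (tokens_left, tokens_occupied) pair, the list is kept sorted ascending by tokens_left, and the scheduler walks the finish events in order, verifying that the KV-token pool never runs dry while the rest of the batch keeps decoding. Here is a self-contained sketch of that loop, mirroring the diff but with hypothetical names and inputs rather than the actual sglang API:

# Hypothetical sketch of the budget check in PrefillAdder.add_one_req_ignore_eos.
from typing import List, Tuple

def decode_is_feasible(cur_rem_tokens: int, req_states: List[Tuple[int, int]]) -> bool:
    # req_states: (tokens_left, tokens_occupied) pairs, sorted ascending by tokens_left.
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        # Bound on decode steps before the next request finishes (its tokens_left),
        # or this request's own tokens_left if it is the last one.
        decode_steps = req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        bs = len(req_states) - i  # requests still decoding in this window
        if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
            return False          # the pool would be exhausted mid-decode
        tokens_freed += tokens_occupied  # a finished request frees its tokens
    return True

print(decode_is_feasible(100, [(10, 50), (30, 40), (60, 20)]))  # -> True
print(decode_is_feasible(20, [(10, 5), (30, 5), (60, 5)]))      # -> False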
python/sglang/srt/managers/tp_worker.py

@@ -445,9 +445,6 @@ class ModelTpServer:
             num_mixed_running,
         )
 
-        if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch)
-
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
             self.current_inflight_req.init_next_round_input(
@@ -465,9 +462,6 @@ class ModelTpServer:
         )
 
         for req in self.waiting_queue:
-            if adder.no_remaining_tokens():
-                break
-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             if (
                 self.lora_paths is not None
                 and len(
@@ -478,6 +472,10 @@ class ModelTpServer:
                 > self.max_loras_per_batch
             ):
                 break
+
+            if adder.no_remaining_tokens():
+                break
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
@@ -507,6 +505,11 @@ class ModelTpServer:
         else:
             tree_cache_hit_rate = 0.0
 
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+        )
+
         if num_mixed_running > 0:
             logger.info(
                 f"Prefill batch"
@@ -515,6 +518,7 @@ class ModelTpServer:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
         else:
@@ -524,6 +528,7 @@ class ModelTpServer:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
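The new num_used figure in tp_worker.py counts KV-pool slots that are neither free nor reclaimable from the prefix tree cache, and the added log field reports it as a fraction of capacity. A rough illustration with made-up numbers, standing in for the real token_to_kv_pool.available_size() and tree_cache.evictable_size() queries:

# Illustrative only: hypothetical pool sizes in place of the real queries.
max_total_num_tokens = 200_000  # capacity of the KV token pool
available_size = 120_000        # free slots (token_to_kv_pool.available_size())
evictable_size = 30_000         # evictable cached prefixes (tree_cache.evictable_size())

num_used = max_total_num_tokens - (available_size + evictable_size)
print(f"token usage: {num_used / max_total_num_tokens:.2f}")  # -> token usage: 0.25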