Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0c1e8796
Unverified
Commit
0c1e8796
authored
Oct 14, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 14, 2024
Browse files
Move filter_batch out of stream_output (#1663)
parent
869f1c02
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
36 deletions
+54
-36
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+22
-12
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+28
-24
test/srt/test_json_constrained.py
test/srt/test_json_constrained.py
+4
-0
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
0c1e8796
...
...
@@ -659,7 +659,7 @@ class ScheduleBatch:
def
check_for_jump_forward
(
self
,
pad_input_ids_func
):
jump_forward_reqs
=
[]
filter
_indices
=
[
i
for
i
in
range
(
len
(
self
.
reqs
))
]
keep
_indices
=
set
(
i
for
i
in
range
(
len
(
self
.
reqs
))
)
for
i
,
req
in
enumerate
(
self
.
reqs
):
if
req
.
jump_forward_map
is
not
None
:
...
...
@@ -719,9 +719,9 @@ class ScheduleBatch:
)
jump_forward_reqs
.
append
(
req
)
filter
_indices
.
remove
(
i
)
keep
_indices
.
remove
(
i
)
self
.
filter_batch
(
filter
_indices
)
self
.
filter_batch
(
keep_indices
=
list
(
keep
_indices
)
)
return
jump_forward_reqs
...
...
@@ -740,19 +740,31 @@ class ScheduleBatch:
self
.
req_pool_indices
,
self
.
seq_lens
-
1
]
=
self
.
out_cache_loc
def
filter_batch
(
self
,
unfinished_indices
:
List
[
int
]):
if
unfinished_indices
is
None
or
len
(
unfinished_indices
)
==
0
:
def
filter_batch
(
self
,
current_inflight_req
:
Optional
[
Req
]
=
None
,
keep_indices
:
Optional
[
List
[
int
]]
=
None
,
):
if
keep_indices
is
None
:
keep_indices
=
[
i
for
i
in
range
(
len
(
self
.
reqs
))
if
not
self
.
reqs
[
i
].
finished
()
and
self
.
reqs
[
i
]
is
not
current_inflight_req
]
if
keep_indices
is
None
or
len
(
keep_indices
)
==
0
:
# Filter out all requests
self
.
reqs
=
[]
return
if
len
(
unfinished
_indices
)
==
len
(
self
.
reqs
):
if
len
(
keep
_indices
)
==
len
(
self
.
reqs
):
# No need to filter
return
self
.
reqs
=
[
self
.
reqs
[
i
]
for
i
in
unfinished
_indices
]
self
.
reqs
=
[
self
.
reqs
[
i
]
for
i
in
keep
_indices
]
new_indices
=
torch
.
tensor
(
unfinished
_indices
,
dtype
=
torch
.
int32
,
device
=
self
.
seq_lens
.
device
keep
_indices
,
dtype
=
torch
.
int32
,
device
=
self
.
seq_lens
.
device
)
self
.
req_pool_indices
=
self
.
req_pool_indices
[
new_indices
]
self
.
seq_lens
=
self
.
seq_lens
[
new_indices
]
...
...
@@ -760,16 +772,14 @@ class ScheduleBatch:
self
.
output_ids
=
self
.
output_ids
[
new_indices
]
self
.
return_logprob
=
any
(
req
.
return_logprob
for
req
in
self
.
reqs
)
if
self
.
return_logprob
:
self
.
top_logprobs_nums
=
[
self
.
top_logprobs_nums
[
i
]
for
i
in
unfinished_indices
]
self
.
top_logprobs_nums
=
[
self
.
top_logprobs_nums
[
i
]
for
i
in
keep_indices
]
else
:
self
.
top_logprobs_nums
=
None
self
.
has_stream
=
any
(
req
.
stream
for
req
in
self
.
reqs
)
self
.
has_regex
=
any
(
req
.
regex_fsm
for
req
in
self
.
reqs
)
self
.
sampling_info
.
filter_batch
(
unfinished
_indices
,
new_indices
)
self
.
sampling_info
.
filter_batch
(
keep
_indices
,
new_indices
)
def
merge_batch
(
self
,
other
:
"ScheduleBatch"
):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
...
...
python/sglang/srt/managers/scheduler.py
View file @
0c1e8796
...
...
@@ -446,31 +446,41 @@ class Scheduler:
exit
(
1
)
if
crash_on_warning
else
None
def
get_next_batch_to_run
(
self
):
# Merge prefill to the running batch
# Merge
the
prefill
batch in
to the running batch
if
(
self
.
last_batch
and
not
self
.
last_batch
.
forward_mode
.
is_decode
()
and
not
self
.
last_batch
.
is_empty
()
):
if
self
.
running_batch
is
None
:
self
.
running_batch
=
self
.
last_batch
else
:
self
.
running_batch
.
merge_batch
(
self
.
last_batch
)
if
self
.
current_inflight_req
:
self
.
last_batch
.
filter_batch
(
self
.
current_inflight_req
)
self
.
batch_is_full
=
False
if
not
self
.
last_batch
.
is_empty
():
if
self
.
running_batch
is
None
:
self
.
running_batch
=
self
.
last_batch
else
:
self
.
running_batch
.
merge_batch
(
self
.
last_batch
)
# Prefill first
new_batch
=
self
.
get_new_batch_prefill
()
if
new_batch
is
not
None
:
return
new_batch
# Run decode
if
self
.
running_batch
is
not
None
:
self
.
update_running_batch
()
if
not
self
.
running_batch
:
return
None
return
self
.
running_batch
else
:
# Check memory
if
self
.
running_batch
is
None
:
self
.
check_memory
()
self
.
new_token_ratio
=
global_config
.
init_new_token_ratio
return
# Run decode
before_bs
=
self
.
running_batch
.
batch_size
()
self
.
update_running_batch
()
if
not
self
.
running_batch
:
self
.
batch_is_full
=
False
return
None
if
before_bs
!=
self
.
running_batch
.
batch_size
():
self
.
batch_is_full
=
False
return
self
.
running_batch
def
get_new_batch_prefill
(
self
)
->
Optional
[
ScheduleBatch
]:
# Handle the cases where prefill is not allowed
...
...
@@ -617,6 +627,11 @@ class Scheduler:
global
test_retract
batch
=
self
.
running_batch
batch
.
filter_batch
()
if
batch
.
is_empty
():
self
.
running_batch
=
None
return
# Check if decode out of memory
if
not
batch
.
check_decode_mem
()
or
(
test_retract
and
batch
.
batch_size
()
>
10
):
old_ratio
=
self
.
new_token_ratio
...
...
@@ -640,8 +655,6 @@ class Scheduler:
if
not
self
.
disable_regex_jump_forward
:
jump_forward_reqs
=
batch
.
check_for_jump_forward
(
self
.
pad_input_ids_func
)
self
.
waiting_queue
.
extend
(
jump_forward_reqs
)
if
jump_forward_reqs
:
self
.
batch_is_full
=
False
if
batch
.
is_empty
():
self
.
running_batch
=
None
return
...
...
@@ -892,14 +905,8 @@ class Scheduler:
output_no_stop_trim
=
[]
else
:
# embedding or reward model
output_embeddings
=
[]
unfinished_indices
=
[]
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
not
req
.
finished
()
and
req
is
not
self
.
current_inflight_req
:
unfinished_indices
.
append
(
i
)
else
:
self
.
batch_is_full
=
False
for
req
in
batch
.
reqs
:
if
req
.
finished
()
or
(
req
.
stream
and
(
...
...
@@ -955,9 +962,6 @@ class Scheduler:
}
output_meta_info
.
append
(
meta_info
)
# Remove finished reqs: update batch tensors
batch
.
filter_batch
(
unfinished_indices
)
# Send to detokenizer
if
output_rids
:
if
self
.
is_generation
:
...
...
test/srt/test_json_constrained.py
View file @
0c1e8796
"""
python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate
"""
import
json
import
unittest
from
concurrent.futures
import
ThreadPoolExecutor
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment