Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a2486eb5
Unverified
Commit
a2486eb5
authored
Dec 08, 2024
by
Lianmin Zheng
Committed by
GitHub
Dec 08, 2024
Browse files
Fix a bug with logprob streaming + chunked prefill (#2403)
parent
61dec545
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
13 deletions
+24
-13
python/sglang/bench_serving.py
python/sglang/bench_serving.py
+8
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+15
-12
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+1
-0
No files found.
python/sglang/bench_serving.py
View file @
a2486eb5
...
@@ -321,6 +321,8 @@ async def async_request_sglang_generate(
...
@@ -321,6 +321,8 @@ async def async_request_sglang_generate(
},
},
"stream"
:
not
args
.
disable_stream
,
"stream"
:
not
args
.
disable_stream
,
"lora_path"
:
request_func_input
.
lora_name
,
"lora_path"
:
request_func_input
.
lora_name
,
"return_logprob"
:
args
.
return_logprob
,
"logprob_start_len"
:
-
1
,
**
request_func_input
.
extra_request_body
,
**
request_func_input
.
extra_request_body
,
}
}
headers
=
{}
headers
=
{}
...
@@ -911,7 +913,7 @@ async def benchmark(
...
@@ -911,7 +913,7 @@ async def benchmark(
prompt
=
test_prompt
,
prompt
=
test_prompt
,
api_url
=
api_url
,
api_url
=
api_url
,
prompt_len
=
test_prompt_len
,
prompt_len
=
test_prompt_len
,
output_len
=
test_output_len
,
output_len
=
min
(
test_output_len
,
32
),
lora_name
=
lora_name
,
lora_name
=
lora_name
,
extra_request_body
=
extra_request_body
,
extra_request_body
=
extra_request_body
,
)
)
...
@@ -1413,6 +1415,11 @@ if __name__ == "__main__":
...
@@ -1413,6 +1415,11 @@ if __name__ == "__main__":
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"Disable ignoring EOS."
,
help
=
"Disable ignoring EOS."
,
)
)
parser
.
add_argument
(
"--return-logprob"
,
action
=
"store_true"
,
help
=
"Return logprob."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--extra-request-body"
,
"--extra-request-body"
,
metavar
=
'{"key1": "value1", "key2": "value2"}'
,
metavar
=
'{"key1": "value1", "key2": "value2"}'
,
...
...
python/sglang/srt/managers/scheduler.py
View file @
a2486eb5
...
@@ -440,16 +440,11 @@ class Scheduler:
...
@@ -440,16 +440,11 @@ class Scheduler:
if
self
.
tp_rank
==
0
or
self
.
server_args
.
enable_dp_attention
:
if
self
.
tp_rank
==
0
or
self
.
server_args
.
enable_dp_attention
:
recv_reqs
=
[]
recv_reqs
=
[]
if
self
.
last_batch
is
None
:
recv_req
=
self
.
recv_from_tokenizer
.
recv_pyobj
()
recv_reqs
.
append
(
recv_req
)
else
:
while
True
:
while
True
:
try
:
try
:
recv_req
=
self
.
recv_from_tokenizer
.
recv_pyobj
(
zmq
.
NOBLOCK
)
recv_req
=
self
.
recv_from_tokenizer
.
recv_pyobj
(
zmq
.
NOBLOCK
)
except
zmq
.
ZMQError
:
except
zmq
.
ZMQError
:
break
break
recv_reqs
.
append
(
recv_req
)
else
:
else
:
recv_reqs
=
None
recv_reqs
=
None
...
@@ -949,6 +944,7 @@ class Scheduler:
...
@@ -949,6 +944,7 @@ class Scheduler:
batch
.
next_batch_sampling_info
.
sampling_info_done
.
set
()
batch
.
next_batch_sampling_info
.
sampling_info_done
.
set
()
def
process_batch_result_prefill
(
self
,
batch
:
ScheduleBatch
,
result
):
def
process_batch_result_prefill
(
self
,
batch
:
ScheduleBatch
,
result
):
skip_stream_req
=
None
if
self
.
is_generation
:
if
self
.
is_generation
:
logits_output
,
next_token_ids
,
bid
=
result
logits_output
,
next_token_ids
,
bid
=
result
...
@@ -1005,6 +1001,10 @@ class Scheduler:
...
@@ -1005,6 +1001,10 @@ class Scheduler:
else
:
else
:
# being chunked reqs' prefill is not finished
# being chunked reqs' prefill is not finished
req
.
is_being_chunked
-=
1
req
.
is_being_chunked
-=
1
# There is only at most one request being currently chunked.
# Because this request does not finish prefill,
# we don't want to stream the request currently being chunked.
skip_stream_req
=
req
if
batch
.
next_batch_sampling_info
:
if
batch
.
next_batch_sampling_info
:
batch
.
next_batch_sampling_info
.
update_regex_vocab_mask
()
batch
.
next_batch_sampling_info
.
update_regex_vocab_mask
()
...
@@ -1034,7 +1034,7 @@ class Scheduler:
...
@@ -1034,7 +1034,7 @@ class Scheduler:
# being chunked reqs' prefill is not finished
# being chunked reqs' prefill is not finished
req
.
is_being_chunked
-=
1
req
.
is_being_chunked
-=
1
self
.
stream_output
(
batch
.
reqs
)
self
.
stream_output
(
batch
.
reqs
,
skip_stream_req
)
def
process_batch_result_decode
(
self
,
batch
:
ScheduleBatch
,
result
):
def
process_batch_result_decode
(
self
,
batch
:
ScheduleBatch
,
result
):
logits_output
,
next_token_ids
,
bid
=
result
logits_output
,
next_token_ids
,
bid
=
result
...
@@ -1179,7 +1179,7 @@ class Scheduler:
...
@@ -1179,7 +1179,7 @@ class Scheduler:
return
num_input_logprobs
return
num_input_logprobs
def
stream_output
(
self
,
reqs
:
List
[
Req
]):
def
stream_output
(
self
,
reqs
:
List
[
Req
]
,
skip_req
:
Optional
[
Req
]
=
None
):
"""Stream the output to detokenizer."""
"""Stream the output to detokenizer."""
output_rids
=
[]
output_rids
=
[]
output_meta_info
:
List
[
dict
]
=
[]
output_meta_info
:
List
[
dict
]
=
[]
...
@@ -1199,6 +1199,9 @@ class Scheduler:
...
@@ -1199,6 +1199,9 @@ class Scheduler:
is_stream_iter
=
self
.
forward_ct_decode
%
self
.
stream_interval
==
0
is_stream_iter
=
self
.
forward_ct_decode
%
self
.
stream_interval
==
0
for
req
in
reqs
:
for
req
in
reqs
:
if
req
is
skip_req
:
continue
# TODO(lianmin): revisit this for overlap + retract + stream
# TODO(lianmin): revisit this for overlap + retract + stream
if
req
.
finished
()
or
(
if
req
.
finished
()
or
(
req
.
stream
and
(
is_stream_iter
or
len
(
req
.
output_ids
)
==
1
)
req
.
stream
and
(
is_stream_iter
or
len
(
req
.
output_ids
)
==
1
)
...
...
python/sglang/test/test_utils.py
View file @
a2486eb5
...
@@ -568,6 +568,7 @@ def run_bench_serving(
...
@@ -568,6 +568,7 @@ def run_bench_serving(
disable_tqdm
=
False
,
disable_tqdm
=
False
,
disable_stream
=
disable_stream
,
disable_stream
=
disable_stream
,
disable_ignore_eos
=
False
,
disable_ignore_eos
=
False
,
return_logprob
=
False
,
lora_name
=
None
,
lora_name
=
None
,
extra_request_body
=
None
,
extra_request_body
=
None
,
profile
=
None
,
profile
=
None
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment