change / sglang · Commits · 1e495e08

Commit 1e495e08 (unverified)
[Fix] Fix select by ensuring each request has at least one token (#1318)

Authored by Lianmin Zheng on Sep 03, 2024; committed by GitHub on Sep 03, 2024
Parent: 12cb115d

Showing 4 changed files with 120 additions and 3 deletions (+120, -3):
python/sglang/srt/managers/schedule_batch.py    +6   -3
python/sglang/test/test_programs.py             +68  -0
python/sglang/utils.py                          +39  -0
test/lang/test_srt_backend.py                   +7   -0
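For orientation, `sgl.select` is the SGLang frontend primitive this fix targets: it scores a fixed list of candidate continuations with the backend and appends the best-scoring one to the prompt state, which requires the scheduler to feed the model at least one token per request. A minimal usage sketch, not part of this commit (the endpoint address is an assumption for illustration):

import sglang as sgl

@sgl.function
def yes_or_no(s, question):
    s += "Question: " + question + "\nAnswer: "
    # select() scores each choice under the model and keeps the best one,
    # so the request must have at least one token to compute logits over.
    s += sgl.select("answer", choices=["Yes", "No"])

# Assumes a locally running SGLang server; the address is illustrative.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = yes_or_no.run(question="Is the sky blue?")
print(state["answer"])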
python/sglang/srt/managers/schedule_batch.py
@@ -178,19 +178,22 @@ class Req:
     def adjust_max_prefix_ids(self):
         self.fill_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.fill_ids)
-        max_prefix_len = input_len
+
+        # FIXME: To work around some bugs in logprob computation, we need to ensure each
+        # request has at least one token. Later, we can relax this requirement and use `input_len`.
+        max_prefix_len = input_len - 1

         if self.sampling_params.max_new_tokens > 0:
             # Need at least one token to compute logits
             max_prefix_len = min(max_prefix_len, input_len - 1)

         if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
             if self.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 max_prefix_len = min(max_prefix_len, input_len - 2)
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

+        max_prefix_len = max(max_prefix_len, 0)
         return self.fill_ids[:max_prefix_len]

     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
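To make the change above concrete, here is a small standalone sketch of the clamping logic (hypothetical helper and values, mirroring the diff): with a fully cached prompt and `max_new_tokens == 0`, as in a `select`-style scoring request, the old `max_prefix_len = input_len` could leave zero tokens for the model to process, while the new `input_len - 1` starting point always leaves at least one.

# Hypothetical standalone version of the clamp in Req.adjust_max_prefix_ids.
def max_prefix_len_for(input_len, max_new_tokens, return_logprob,
                       logprob_start_len, has_normalized_prompt_logprob):
    # New behavior: never let the cached prefix cover the whole request,
    # so at least one token is always fed to the model.
    max_prefix_len = input_len - 1

    if max_new_tokens > 0:
        # Need at least one token to compute logits
        max_prefix_len = min(max_prefix_len, input_len - 1)

    if return_logprob:
        if not has_normalized_prompt_logprob:
            # Need at least two tokens to compute normalized logprob
            max_prefix_len = min(max_prefix_len, input_len - 2)
        max_prefix_len = min(max_prefix_len, logprob_start_len)

    return max(max_prefix_len, 0)

# A scoring request whose 8-token prompt is fully covered by the prefix cache:
# the old code allowed a prefix of 8 (nothing left to run); the clamp keeps one token.
print(max_prefix_len_for(8, max_new_tokens=0, return_logprob=False,
                         logprob_start_len=0, has_normalized_prompt_logprob=True))  # -> 7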
python/sglang/test/test_programs.py
@@ -2,8 +2,12 @@
import json
import re
import time

import numpy as np

import sglang as sgl
from sglang.utils import fetch_and_cache_jsonl


def test_few_shot_qa():
@@ -447,3 +451,67 @@ def test_chat_completion_speculative():
    )

    gen_character_spec().sync()


def test_hellaswag_select():
    """Benchmark the accuracy of sgl.select on the HellaSwag dataset."""
    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
    lines = fetch_and_cache_jsonl(url)

    # Construct prompts
    def get_one_example(lines, i, include_answer):
        ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
        if include_answer:
            ret += lines[i]["endings"][lines[i]["label"]]
        return ret

    def get_few_shot_examples(lines, k):
        ret = ""
        for i in range(k):
            ret += get_one_example(lines, i, True) + "\n\n"
        return ret

    num_questions = 200
    num_shots = 20
    few_shot_examples = get_few_shot_examples(lines, num_shots)

    questions = []
    choices = []
    labels = []
    for i in range(len(lines[:num_questions])):
        questions.append(get_one_example(lines, i, False))
        choices.append(lines[i]["endings"])
        labels.append(lines[i]["label"])
    arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    @sgl.function
    def few_shot_hellaswag(s, question, choices):
        s += few_shot_examples + question
        s += sgl.select("answer", choices=choices)

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Run requests
    tic = time.time()
    rets = few_shot_hellaswag.run_batch(
        arguments,
        temperature=0,
        num_threads=64,
        progress_bar=True,
    )
    preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
    latency = time.time() - tic

    # Compute accuracy
    accuracy = np.mean(np.array(preds) == np.array(labels))

    return accuracy, latency
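One possible way to run the new benchmark by hand, outside the unittest harness further down; the server launch command and address are assumptions for illustration, not part of this commit:

# Hypothetical manual invocation of the new HellaSwag benchmark.
# Assumes an SGLang server is already running, e.g. started with something like:
#   python -m sglang.launch_server --model-path <model> --port 30000
import sglang as sgl
from sglang.test.test_programs import test_hellaswag_select

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
accuracy, latency = test_hellaswag_select()
print(f"accuracy={accuracy:.3f}, latency={latency:.1f}s")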
python/sglang/utils.py
@@ -4,6 +4,7 @@ import base64
import importlib
import json
import logging
import os
import signal
import sys
import traceback
@@ -15,6 +16,7 @@ from typing import Union
import numpy as np
import requests
from tqdm import tqdm

logger = logging.getLogger(__name__)
@@ -260,3 +262,40 @@ class LazyImport:
    def __call__(self, *args, **kwargs):
        module = self._load()
        return module(*args, **kwargs)


def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"):
    """Read and cache a jsonl file from a url."""

    # Check if the cache file already exists
    if os.path.exists(cache_file):
        print("Loading data from cache...")
        with open(cache_file, "r") as f:
            data = [json.loads(line) for line in f]
    else:
        print("Downloading data from URL...")
        # Stream the response to show the progress bar
        response = requests.get(url, stream=True)
        response.raise_for_status()  # Check for request errors

        # Total size of the file in bytes
        total_size = int(response.headers.get("content-length", 0))
        chunk_size = 1024  # Download in chunks of 1KB

        # Use tqdm to display the progress bar
        with open(cache_file, "wb") as f, tqdm(
            desc=cache_file,
            total=total_size,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
                bar.update(len(chunk))

        # Convert the data to a list of dictionaries
        with open(cache_file, "r") as f:
            data = [json.loads(line) for line in f]

    return data
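A short usage sketch for the new helper; the URL is the same HellaSwag file used in the test above, and the cache path is illustrative:

# Illustrative use of fetch_and_cache_jsonl: the first call downloads the file with
# a tqdm progress bar, later calls read the local cache instead.
from sglang.utils import fetch_and_cache_jsonl

url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
lines = fetch_and_cache_jsonl(url, cache_file="hellaswag_val.jsonl")
print(len(lines), lines[0]["activity_label"])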
test/lang/test_srt_backend.py
@@ -7,6 +7,7 @@ from sglang.test.test_programs import (
    test_dtype_gen,
    test_expert_answer,
    test_few_shot_qa,
    test_hellaswag_select,
    test_mt_bench,
    test_parallel_decoding,
    test_regex,
@@ -62,6 +63,12 @@ class TestSRTBackend(unittest.TestCase):
    def test_dtype_gen(self):
        test_dtype_gen()

    def test_hellaswag_select(self):
        # Run twice to capture more bugs
        for _ in range(2):
            accuracy, latency = test_hellaswag_select()
            assert accuracy > 0.71


if __name__ == "__main__":
    unittest.main()