Commit af02f99b (unverified)

Add more logprob tests (#3162)

Authored by Lianmin Zheng on Jan 26, 2025; committed via GitHub on Jan 26, 2025.
Parent: 9472e699

Showing 1 changed file with 115 additions and 2 deletions:

  test/srt/test_srt_endpoint.py  (+115, -2)
test/srt/test_srt_endpoint.py @ af02f99b
@@ -32,7 +32,11 @@ class TestSRTEndpoint(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=("--enable-custom-logit-processor",),
+            other_args=(
+                "--enable-custom-logit-processor",
+                "--mem-fraction-static",
+                "0.8",
+            ),
         )

     @classmethod
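The only change in this hunk is the extra `--mem-fraction-static 0.8` server flag, which caps the fraction of GPU memory reserved for model weights and the KV-cache pool when the test server is launched. For reference, below is a minimal sketch of launching a comparable server outside the test class; the launcher helper and its import path are not visible in this hunk and are assumed here, and the model path and URL are placeholders.

# A minimal sketch (not part of this commit). It assumes the
# sglang.test.test_utils.popen_launch_server helper that these tests
# typically use; the model path and base URL below are placeholders.
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    popen_launch_server,
)

MODEL = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model path
BASE_URL = "http://127.0.0.1:30000"         # placeholder base URL

process = popen_launch_server(
    MODEL,
    BASE_URL,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=(
        "--enable-custom-logit-processor",
        "--mem-fraction-static",
        "0.8",
    ),
)
# ... send requests against BASE_URL, then shut the server down.
process.terminate()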
@@ -155,14 +159,26 @@ class TestSRTEndpoint(unittest.TestCase):
                 },
                 "return_logprob": True,
                 "logprob_start_len": -1,
                 "top_logprobs_num": 5,
             },
         )
         response_json = response.json()
-        print(json.dumps(response_json, indent=2))
+        # print(json.dumps(response_json, indent=2))
         res = response_json
         self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
+
+        # Test the number of tokens are correct
+        self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
+        self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), new_tokens)
+
+        # Test the top-1 tokens are the same as output tokens (because temp = 0.0)
+        for i in range(new_tokens):
+            self.assertListEqual(
+                res["meta_info"]["output_token_logprobs"][i],
+                res["meta_info"]["output_top_logprobs"][i][0],
+            )
+            self.assertEqual(len(res["meta_info"]["output_top_logprobs"][i]), 5)

     def test_logprob_match(self):
         """Test the output logprobs are close to the input logprobs if we run a prefill again."""
@@ -221,6 +237,103 @@ class TestSRTEndpoint(unittest.TestCase):
         max_diff = np.max(diff)
         self.assertLess(max_diff, 0.25)
 
+    def run_logprob_check(self, arg):
+        (
+            input_len,
+            output_len,
+            temperature,
+            logprob_start_len,
+            return_logprob,
+            top_logprobs_num,
+        ) = arg
+        input_ids = list(range(input_len))
+
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": input_ids,
+                "sampling_params": {
+                    "temperature": temperature,
+                    "max_new_tokens": output_len,
+                },
+                "return_logprob": return_logprob,
+                "logprob_start_len": logprob_start_len,
+                "top_logprobs_num": top_logprobs_num,
+            },
+        )
+        response_json = response.json()
+        res = response_json
+
+        self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
+        self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
+
+        # Test the number of tokens are correct
+        if return_logprob:
+            # This is because if logprob_start_len == 0, we added a padding for the first token.
+            # In other cases, we do not add the padding
+            delta = 0 if logprob_start_len == 0 else 1
+            self.assertEqual(
+                len(res["meta_info"]["input_token_logprobs"])
+                + logprob_start_len
+                + delta,
+                res["meta_info"]["prompt_tokens"],
+            )
+            self.assertEqual(
+                len(res["meta_info"]["output_token_logprobs"]), output_len
+            )
+
+            if top_logprobs_num:
+                self.assertEqual(
+                    len(res["meta_info"]["input_top_logprobs"])
+                    + logprob_start_len
+                    + delta,
+                    res["meta_info"]["prompt_tokens"],
+                )
+                self.assertEqual(
+                    len(res["meta_info"]["output_top_logprobs"]), output_len
+                )
+
+                for i in range(output_len):
+                    self.assertEqual(
+                        len(res["meta_info"]["output_top_logprobs"][i]),
+                        top_logprobs_num,
+                    )
+
+                    # Test the top-1 tokens are the same as output tokens if temperature == 0
+                    if temperature == 0:
+                        self.assertListEqual(
+                            res["meta_info"]["output_token_logprobs"][i],
+                            res["meta_info"]["output_top_logprobs"][i][0],
+                        )
+
+    def test_logprob_mixed(self):
+        args = []
+        temperature = 0
+        # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
+        for input_len in [1000, 2000]:
+            for output_len in [4, 8]:
+                for logprob_start_len in [0, 500, 1000]:
+                    for return_logprob in [True, False]:
+                        for top_logprobs_num in [0, 5]:
+                            if logprob_start_len >= input_len:
+                                continue
+
+                            args.append(
+                                (
+                                    input_len,
+                                    output_len,
+                                    temperature,
+                                    logprob_start_len,
+                                    return_logprob,
+                                    top_logprobs_num,
+                                )
+                            )
+
+        random.shuffle(args)
+
+        with ThreadPoolExecutor(8) as executor:
+            list(executor.map(self.run_logprob_check, args))
+
     def test_logprob_grammar(self):
         prompts = "Question: Is Paris the Capital of France? Answer:"
         allowed_tokens = [" Yes", " No"]
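One subtle point in run_logprob_check is the delta term: as the in-code comment says, when logprob_start_len is 0 a padding entry is added for the first prompt token, so the input logprob list covers every prompt position; for any other start length no padding is added and the list is one entry shorter than prompt_tokens - logprob_start_len. The small check below (not part of the commit) restates the asserted identity as plain arithmetic, using values from test_logprob_mixed's parameter grid.

# A small worked example (not part of this commit) of the identity asserted above:
#   len(input_token_logprobs) + logprob_start_len + delta == prompt_tokens
# where delta is 0 when logprob_start_len == 0 (a padding entry is added for the
# first token) and 1 otherwise.
def expected_input_logprob_len(prompt_tokens: int, logprob_start_len: int) -> int:
    delta = 0 if logprob_start_len == 0 else 1
    return prompt_tokens - logprob_start_len - delta

assert expected_input_logprob_len(1000, 0) == 1000   # padded: one entry per prompt token
assert expected_input_logprob_len(1000, 500) == 499  # no padding entry in this case
assert expected_input_logprob_len(2000, 1000) == 999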