sglang · Commit af02f99b (Unverified)
Add more logprob tests (#3162)

Authored Jan 26, 2025 by Lianmin Zheng; committed via GitHub on Jan 26, 2025
Parent: 9472e699

Showing 1 changed file with 115 additions and 2 deletions:

test/srt/test_srt_endpoint.py (+115, -2)
@@ -32,7 +32,11 @@ class TestSRTEndpoint(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=("--enable-custom-logit-processor",),
+            other_args=(
+                "--enable-custom-logit-processor",
+                "--mem-fraction-static",
+                "0.8",
+            ),
         )

     @classmethod
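For context, the `other_args` tuple above is passed to the server launcher in the test's setUpClass. Below is a minimal sketch of that setup; only the argument lines visible in the diff are confirmed by this commit, while the helper and constant names (popen_launch_server, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, and the default test model/URL constants from sglang.test.test_utils) are assumptions based on sglang's test utilities:

# Sketch of the setUpClass that the hunk above modifies. Helper and constant
# names are assumptions; only the argument lines shown in the diff are taken
# from this commit.
import unittest

from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestSRTEndpoint(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=(
                "--enable-custom-logit-processor",
                # Reserve a smaller static memory pool so the additional
                # logprob tests have headroom on the test GPU.
                "--mem-fraction-static",
                "0.8",
            ),
        )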
@@ -155,14 +159,26 @@ class TestSRTEndpoint(unittest.TestCase):
                 },
                 "return_logprob": True,
                 "logprob_start_len": -1,
+                "top_logprobs_num": 5,
             },
         )
         response_json = response.json()
-        print(json.dumps(response_json, indent=2))
+        # print(json.dumps(response_json, indent=2))
         res = response_json
         self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)

+        # Test the number of tokens are correct
         self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
+        self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), new_tokens)
+
+        # Test the top-1 tokens are the same as output tokens (because temp = 0.0)
+        for i in range(new_tokens):
+            self.assertListEqual(
+                res["meta_info"]["output_token_logprobs"][i],
+                res["meta_info"]["output_top_logprobs"][i][0],
+            )
+            self.assertEqual(len(res["meta_info"]["output_top_logprobs"][i]), 5)
+
     def test_logprob_match(self):
         """Test the output logprobs are close to the input logprobs if we run a prefill again."""
@@ -221,6 +237,103 @@ class TestSRTEndpoint(unittest.TestCase):
         max_diff = np.max(diff)
         self.assertLess(max_diff, 0.25)

+    def run_logprob_check(self, arg):
+        (
+            input_len,
+            output_len,
+            temperature,
+            logprob_start_len,
+            return_logprob,
+            top_logprobs_num,
+        ) = arg
+
+        input_ids = list(range(input_len))
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": input_ids,
+                "sampling_params": {
+                    "temperature": temperature,
+                    "max_new_tokens": output_len,
+                },
+                "return_logprob": return_logprob,
+                "logprob_start_len": logprob_start_len,
+                "top_logprobs_num": top_logprobs_num,
+            },
+        )
+        response_json = response.json()
+
+        res = response_json
+        self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
+        self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
+
+        # Test the number of tokens are correct
+        if return_logprob:
+            # This is because if logprob_start_len == 0, we added a padding for the first token.
+            # In other cases, we do not add the padding
+            delta = 0 if logprob_start_len == 0 else 1
+
+            self.assertEqual(
+                len(res["meta_info"]["input_token_logprobs"])
+                + logprob_start_len
+                + delta,
+                res["meta_info"]["prompt_tokens"],
+            )
+            self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
+
+            if top_logprobs_num:
+                self.assertEqual(
+                    len(res["meta_info"]["input_top_logprobs"])
+                    + logprob_start_len
+                    + delta,
+                    res["meta_info"]["prompt_tokens"],
+                )
+
+                self.assertEqual(
+                    len(res["meta_info"]["output_top_logprobs"]), output_len
+                )
+
+                for i in range(output_len):
+                    self.assertEqual(
+                        len(res["meta_info"]["output_top_logprobs"][i]),
+                        top_logprobs_num,
+                    )
+
+                    # Test the top-1 tokens are the same as output tokens if temperature == 0
+                    if temperature == 0:
+                        self.assertListEqual(
+                            res["meta_info"]["output_token_logprobs"][i],
+                            res["meta_info"]["output_top_logprobs"][i][0],
+                        )
+
+    def test_logprob_mixed(self):
+        args = []
+        temperature = 0
+        # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
+        for input_len in [1000, 2000]:
+            for output_len in [4, 8]:
+                for logprob_start_len in [0, 500, 1000]:
+                    for return_logprob in [True, False]:
+                        for top_logprobs_num in [0, 5]:
+                            if logprob_start_len >= input_len:
+                                continue
+
+                            args.append(
+                                (
+                                    input_len,
+                                    output_len,
+                                    temperature,
+                                    logprob_start_len,
+                                    return_logprob,
+                                    top_logprobs_num,
+                                )
+                            )
+
+        random.shuffle(args)
+
+        with ThreadPoolExecutor(8) as executor:
+            list(executor.map(self.run_logprob_check, args))
+
     def test_logprob_grammar(self):
         prompts = "Question: Is Paris the Capital of France? Answer:"
         allowed_tokens = [" Yes", " No"]
...
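The length bookkeeping asserted in run_logprob_check boils down to len(input_token_logprobs) == prompt_tokens - logprob_start_len - delta, where delta is 0 only when logprob_start_len == 0 (a padding entry is added for the first token in that case). The snippet below works through that arithmetic for two of the combinations swept by test_logprob_mixed; the helper function is hypothetical and not part of the test file:

# Illustrative, hypothetical helper: expected length of
# meta_info["input_token_logprobs"] under the rule asserted in
# run_logprob_check above.
def expected_input_logprob_len(prompt_tokens: int, logprob_start_len: int) -> int:
    # When logprob_start_len == 0 a padding entry is added for the first
    # token, so no extra position is dropped; otherwise one position is.
    delta = 0 if logprob_start_len == 0 else 1
    return prompt_tokens - logprob_start_len - delta


# Two of the combinations swept by test_logprob_mixed:
assert expected_input_logprob_len(1000, 0) == 1000    # padded first token included
assert expected_input_logprob_len(2000, 500) == 1499  # 500 skipped + 1 unpadded position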