Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ae7ee01a
Unverified
Commit
ae7ee01a
authored
Aug 01, 2024
by
Ying Sheng
Committed by
GitHub
Aug 01, 2024
Browse files
Add accuracy test to CI: MMLU (#882)
parent
76e59088
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
111 additions
and
32 deletions
+111
-32
test/srt/deprecated/test_robust.py
test/srt/deprecated/test_robust.py
+0
-0
test/srt/test_eval_accuracy.py
test/srt/test_eval_accuracy.py
+43
-0
test/srt/test_openai_server.py
test/srt/test_openai_server.py
+4
-32
test/srt/test_srt_endpoint.py
test/srt/test_srt_endpoint.py
+64
-0
No files found.
test/srt/
ol
d/test_robust.py
→
test/srt/
deprecate
d/test_robust.py
View file @
ae7ee01a
File moved
test/srt/test_eval_accuracy.py
0 → 100644
View file @
ae7ee01a
import
json
import
unittest
from
types
import
SimpleNamespace
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
MODEL_NAME_FOR_TEST
,
popen_launch_server
class TestAccuracy(unittest.TestCase):
    """Accuracy check that runs a small MMLU evaluation against a test server."""

    @classmethod
    def setUpClass(cls):
        """Start the server once for the whole class and record how to reach it."""
        port = 30000
        cls.model = MODEL_NAME_FOR_TEST
        cls.base_url = f"http://localhost:{port}"
        # The handle is kept so tearDownClass can clean up by PID.
        cls.process = popen_launch_server(cls.model, port, timeout=300)

    @classmethod
    def tearDownClass(cls):
        """Shut down the server started in setUpClass via kill_child_process."""
        kill_child_process(cls.process.pid)

    def test_mmlu(self):
        """Run the "mmlu" eval on 20 examples and require a score of at least 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=20,
            num_threads=20,
        )

        metrics = run_eval(args)
        # Use the unittest assertion rather than a bare ``assert``: it is not
        # stripped under ``python -O`` and reports both values on failure.
        self.assertGreaterEqual(metrics["score"], 0.5)
# Run the test suite when this file is executed directly; warnings are
# silenced through unittest's ``warnings`` argument.
if __name__ == "__main__":
    unittest.main(warnings="ignore")

    # Manual driver kept for debugging without the unittest runner:
    # t = TestAccuracy()
    # t.setUpClass()
    # t.test_mmlu()
    # t.tearDownClass()
test/srt/test_openai_server.py
View file @
ae7ee01a
import
json
import
subprocess
import
time
import
unittest
import
openai
import
requests
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.test_utils
import
MODEL_NAME_FOR_TEST
,
popen_launch_server
class
TestOpenAIServer
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
model
=
"meta-llama/Meta-Llama-3.1-8B-Instruct"
port
=
30000
timeout
=
300
command
=
[
"python3"
,
"-m"
,
"sglang.launch_server"
,
"--model-path"
,
model
,
"--host"
,
"localhost"
,
"--port"
,
str
(
port
),
]
cls
.
process
=
subprocess
.
Popen
(
command
,
stdout
=
None
,
stderr
=
None
)
cls
.
model
=
MODEL_NAME_FOR_TEST
cls
.
base_url
=
f
"http://localhost:
{
port
}
/v1"
cls
.
model
=
model
start_time
=
time
.
time
()
while
time
.
time
()
-
start_time
<
timeout
:
try
:
response
=
requests
.
get
(
f
"
{
cls
.
base_url
}
/models"
)
if
response
.
status_code
==
200
:
return
except
requests
.
RequestException
:
pass
time
.
sleep
(
10
)
raise
TimeoutError
(
"Server failed to start within the timeout period."
)
cls
.
process
=
popen_launch_server
(
cls
.
model
,
port
,
timeout
=
300
)
@
classmethod
def
tearDownClass
(
cls
):
...
...
@@ -178,8 +152,6 @@ class TestOpenAIServer(unittest.TestCase):
is_first
=
True
for
response
in
generator
:
print
(
response
)
data
=
response
.
choices
[
0
].
delta
if
is_first
:
data
.
role
==
"assistant"
...
...
test/srt/test_srt_endpoint.py
0 → 100644
View file @
ae7ee01a
import
json
import
unittest
import
requests
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
MODEL_NAME_FOR_TEST
,
popen_launch_server
class TestSRTEndpoint(unittest.TestCase):
    """Smoke tests that POST to the server's ``/generate`` endpoint."""

    @classmethod
    def setUpClass(cls):
        """Start the server once for the whole class and record how to reach it."""
        port = 30000
        cls.model = MODEL_NAME_FOR_TEST
        cls.base_url = f"http://localhost:{port}"
        # The handle is kept so tearDownClass can clean up by PID.
        cls.process = popen_launch_server(cls.model, port, timeout=300)

    @classmethod
    def tearDownClass(cls):
        """Shut down the server started in setUpClass via kill_child_process."""
        kill_child_process(cls.process.pid)

    def run_decode(
        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
    ):
        """Send one non-streaming ``/generate`` request and print the JSON reply.

        Args:
            return_logprob: Forwarded as the request's ``return_logprob`` field.
            top_logprobs_num: Forwarded as ``top_logprobs_num``.
            return_text: Forwarded as ``return_text_in_logprobs``.
            n: Forwarded as the ``n`` sampling parameter.

        Raises:
            AssertionError: If the server does not answer with HTTP 200.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    # Temperature 0 for a single sample, 0.5 when several are
                    # requested (presumably so the samples can differ).
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 32,
                    "n": n,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )
        # Without this check an error reply that carries a JSON body would be
        # printed and the test would still pass.
        self.assertEqual(response.status_code, 200, msg=response.text)
        print(json.dumps(response.json()))
        print("=" * 100)

    def test_simple_decode(self):
        """Decode with every option left at its default."""
        self.run_decode()

    def test_parallel_sample(self):
        """Request three samples for the same prompt."""
        self.run_decode(n=3)

    def test_logprob(self):
        """Request logprobs for each combination of top-k count and text flag."""
        for top_logprobs_num in [0, 3]:
            for return_text in [True, False]:
                # subTest names the failing combination and lets the
                # remaining combinations run.
                with self.subTest(
                    top_logprobs_num=top_logprobs_num, return_text=return_text
                ):
                    self.run_decode(
                        return_logprob=True,
                        top_logprobs_num=top_logprobs_num,
                        return_text=return_text,
                    )
# Run the test suite when this file is executed directly; warnings are
# silenced through unittest's ``warnings`` argument.
if __name__ == "__main__":
    unittest.main(warnings="ignore")
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment