Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
442534aa
"vscode:/vscode.git/clone" did not exist on "0b3ddec6540d7fc7fb59c1b6184a5e6c9e1d32e0"
Unverified
Commit
442534aa
authored
Aug 09, 2025
by
fzyzcjy
Committed by
GitHub
Aug 09, 2025
Browse files
Add CI for gpt-oss model on hopper (#8851)
parent
de8b8b6e
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
187 additions
and
2 deletions
+187
-2
python/sglang/test/run_eval.py
python/sglang/test/run_eval.py
+4
-1
python/sglang/test/simple_eval_common.py
python/sglang/test/simple_eval_common.py
+6
-0
python/sglang/test/simple_eval_gpqa.py
python/sglang/test/simple_eval_gpqa.py
+2
-0
test/srt/run_suite.py
test/srt/run_suite.py
+3
-1
test/srt/test_gpt_oss_1gpu.py
test/srt/test_gpt_oss_1gpu.py
+31
-0
test/srt/test_gpt_oss_4gpu.py
test/srt/test_gpt_oss_4gpu.py
+42
-0
test/srt/test_gpt_oss_common.py
test/srt/test_gpt_oss_common.py
+99
-0
No files found.
python/sglang/test/run_eval.py
View file @
442534aa
...
...
@@ -65,9 +65,10 @@ def run_eval(args):
sampler
=
ChatCompletionSampler
(
model
=
args
.
model
,
max_tokens
=
2048
,
max_tokens
=
getattr
(
args
,
"max_tokens"
,
2048
)
,
base_url
=
base_url
,
temperature
=
getattr
(
args
,
"temperature"
,
0.0
),
reasoning_effort
=
getattr
(
args
,
"reasoning_effort"
,
None
),
)
# Run eval
...
...
@@ -120,7 +121,9 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--eval-name"
,
type
=
str
,
default
=
"mmlu"
)
parser
.
add_argument
(
"--num-examples"
,
type
=
int
)
parser
.
add_argument
(
"--num-threads"
,
type
=
int
,
default
=
512
)
parser
.
add_argument
(
"--max-tokens"
,
type
=
int
,
default
=
2048
)
parser
.
add_argument
(
"--temperature"
,
type
=
float
,
default
=
0.0
)
parser
.
add_argument
(
"--reasoning-effort"
,
type
=
str
)
args
=
parser
.
parse_args
()
run_eval
(
args
)
python/sglang/test/simple_eval_common.py
View file @
442534aa
...
...
@@ -91,6 +91,7 @@ class ChatCompletionSampler(SamplerBase):
model
:
Optional
[
str
]
=
None
,
system_message
:
Optional
[
str
]
=
None
,
temperature
:
float
=
0.0
,
reasoning_effort
:
Optional
[
str
]
=
None
,
max_tokens
:
int
=
2048
,
):
self
.
client
=
OpenAI
(
base_url
=
base_url
,
http_client
=
LargerHttpxClient
())
...
...
@@ -102,7 +103,11 @@ class ChatCompletionSampler(SamplerBase):
self
.
system_message
=
system_message
self
.
temperature
=
temperature
self
.
max_tokens
=
max_tokens
self
.
reasoning_effort
=
reasoning_effort
self
.
image_format
=
"url"
print
(
f
"ChatCompletionSampler initialized with
{
self
.
system_message
=
}
{
self
.
temperature
=
}
{
self
.
max_tokens
=
}
{
self
.
reasoning_effort
=
}
"
)
def
_handle_image
(
self
,
...
...
@@ -138,6 +143,7 @@ class ChatCompletionSampler(SamplerBase):
messages
=
message_list
,
temperature
=
self
.
temperature
,
max_tokens
=
self
.
max_tokens
,
reasoning_effort
=
self
.
reasoning_effort
,
)
return
response
.
choices
[
0
].
message
.
content
# NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
...
...
python/sglang/test/simple_eval_gpqa.py
View file @
442534aa
...
...
@@ -71,6 +71,8 @@ class GPQAEval(Eval):
)
]
response_text
=
sampler
(
prompt_messages
)
if
response_text
is
None
:
response_text
=
""
match
=
re
.
search
(
ANSWER_PATTERN_MULTICHOICE
,
response_text
)
extracted_answer
=
match
.
group
(
1
)
if
match
else
None
score
=
1.0
if
extracted_answer
==
correct_answer
else
0.0
...
...
test/srt/run_suite.py
View file @
442534aa
...
...
@@ -63,6 +63,7 @@ suites = {
TestFile
(
"test_fp8_kernel.py"
,
8
),
TestFile
(
"test_function_call_parser.py"
,
10
),
TestFile
(
"test_fused_moe.py"
,
30
),
TestFile
(
"test_gpt_oss_1gpu.py"
,
600
),
TestFile
(
"test_hicache.py"
,
116
),
TestFile
(
"test_hicache_mla.py"
,
127
),
TestFile
(
"test_hicache_storage.py"
,
127
),
...
...
@@ -104,7 +105,7 @@ suites = {
TestFile
(
"test_utils_update_weights.py"
,
48
),
TestFile
(
"test_vision_chunked_prefill.py"
,
175
),
TestFile
(
"test_vlm_input_format.py"
,
300
),
TestFile
(
"test_vision_openai_server_a.py"
,
584
),
TestFile
(
"test_vision_openai_server_a.py"
,
989
),
TestFile
(
"test_vision_openai_server_b.py"
,
620
),
TestFile
(
"test_w8a8_quantization.py"
,
46
),
TestFile
(
"test_reasoning_parser.py"
,
5
),
...
...
@@ -176,6 +177,7 @@ suites = {
TestFile
(
"test_update_weights_from_distributed.py"
,
103
),
],
"per-commit-4-gpu"
:
[
TestFile
(
"test_gpt_oss_4gpu.py"
,
600
),
TestFile
(
"test_local_attn.py"
,
250
),
TestFile
(
"test_pp_single_node.py"
,
372
),
TestFile
(
"test_multi_instance_release_memory_occupation.py"
,
64
),
...
...
test/srt/test_gpt_oss_1gpu.py
0 → 100644
View file @
442534aa
import
unittest
from
test_gpt_oss_common
import
BaseTestGptOss
class TestGptOss1Gpu(BaseTestGptOss):
    """Single-GPU GPQA accuracy tests for the 20b gpt-oss model variants.

    Both quantizations are expected to reach the same scores, so the
    per-reasoning-effort threshold table is shared between the two cases.
    """

    # Minimum GPQA score per reasoning effort for the 20b variant.
    # TODO investigate why "high" scores below "low"/"medium".
    _EXPECTED_SCORES_20B = {
        "low": 0.38,
        "medium": 0.38,
        "high": 0.29,
    }

    def test_mxfp4_20b(self):
        """20b model with native mxfp4 quantization."""
        self.run_test(
            model_variant="20b",
            quantization="mxfp4",
            # Pass a copy so the shared table can never be mutated downstream.
            expected_score_of_reasoning_effort=dict(self._EXPECTED_SCORES_20B),
        )

    def test_bf16_20b(self):
        """20b model dequantized to bf16."""
        self.run_test(
            model_variant="20b",
            quantization="bf16",
            expected_score_of_reasoning_effort=dict(self._EXPECTED_SCORES_20B),
        )


if __name__ == "__main__":
    unittest.main()
test/srt/test_gpt_oss_4gpu.py
0 → 100644
View file @
442534aa
import
unittest
from
test_gpt_oss_common
import
BaseTestGptOss
class TestGptOss4Gpu(BaseTestGptOss):
    """4-GPU (tp=4) GPQA accuracy tests for the 120b gpt-oss model variants."""

    # Minimum GPQA score per reasoning effort for the 120b variant.
    # Only "low" is checked to keep CI fast; the removed entries scored the
    # same when enabled: "medium": 0.61, "high": 0.61.
    _EXPECTED_SCORES_120B = {
        "low": 0.61,
    }

    def test_bf16_120b(self):
        """120b model dequantized to bf16 across 4 GPUs."""
        self.run_test(
            model_variant="120b",
            quantization="bf16",
            # Pass a copy so the shared table can never be mutated downstream.
            expected_score_of_reasoning_effort=dict(self._EXPECTED_SCORES_120B),
            other_args=["--tp", "4", "--cuda-graph-max-bs", "200"],
        )

    def test_mxfp4_120b(self):
        """120b model with native mxfp4 quantization across 4 GPUs."""
        self.run_test(
            model_variant="120b",
            quantization="mxfp4",
            expected_score_of_reasoning_effort=dict(self._EXPECTED_SCORES_120B),
            other_args=[
                "--tp",
                "4",
                "--cuda-graph-max-bs",
                "200",
                "--mem-fraction-static",
                "0.93",
            ],
        )


if __name__ == "__main__":
    unittest.main()
test/srt/test_gpt_oss_common.py
0 → 100644
View file @
442534aa
from
concurrent.futures
import
ThreadPoolExecutor
from
types
import
SimpleNamespace
from
typing
import
Dict
,
List
,
Literal
,
Optional
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_server
,
)
_base_url
=
DEFAULT_URL_FOR_TEST
class BaseTestGptOss(CustomTestCase):
    """Shared harness for gpt-oss GPQA accuracy tests.

    Subclasses call :meth:`run_test` with a model variant/quantization pair
    and the minimum GPQA score expected for each reasoning effort level.
    """

    # Maps (model_variant, quantization) to the model id to launch.
    _MODEL_OF_VARIANT = {
        ("20b", "bf16"): "lmsys/gpt-oss-20b-bf16",
        ("120b", "bf16"): "lmsys/gpt-oss-120b-bf16",
        ("20b", "mxfp4"): "openai/gpt-oss-20b",
        ("120b", "mxfp4"): "openai/gpt-oss-120b",
    }

    def run_test(
        self,
        model_variant: Literal["20b", "120b"],
        quantization: Literal["mxfp4", "bf16"],
        expected_score_of_reasoning_effort: Dict[str, float],
        other_args: Optional[List[str]] = None,
    ):
        """Launch the chosen model and check GPQA scores per reasoning effort.

        Args:
            model_variant: "20b" or "120b".
            quantization: "mxfp4" (native) or "bf16" (dequantized weights).
            expected_score_of_reasoning_effort: minimum acceptable GPQA score
                keyed by reasoning effort (e.g. "low"/"medium"/"high").
            other_args: extra server CLI flags; never mutated.
        """
        # Bug fix: copy instead of `other_args += [...]` on the parameter —
        # `+=` extends a caller-supplied list in place (aliasing), so a caller
        # reusing its list across calls would accumulate flags.
        other_args = list(other_args) if other_args is not None else []

        model = self._MODEL_OF_VARIANT[(model_variant, quantization)]

        if model_variant == "20b":
            other_args += ["--cuda-graph-max-bs", "600"]

        self._run_test_raw(
            model=model,
            expected_score_of_reasoning_effort=expected_score_of_reasoning_effort,
            other_args=other_args,
        )

    def _run_test_raw(
        self,
        model: str,
        expected_score_of_reasoning_effort: Dict[str, float],
        other_args: List[str],
    ):
        """Start the server once, then run one eval per reasoning effort."""
        process = popen_launch_server(
            model,
            _base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )
        try:
            # Run multiple evals in parallel since wall time is mostly bound
            # by the longest generated sequence instead of the number of
            # questions.
            with ThreadPoolExecutor(max_workers=4) as executor:
                # list() drains the iterator so worker exceptions propagate.
                list(
                    executor.map(
                        lambda d: self._run_one_eval(**d),
                        [
                            dict(
                                model=model,
                                reasoning_effort=reasoning_effort,
                                expected_score=expected_score,
                            )
                            for reasoning_effort, expected_score in expected_score_of_reasoning_effort.items()
                        ],
                    )
                )
        finally:
            # Always tear the server down, even if an eval assertion fails.
            kill_process_tree(process.pid)

    def _run_one_eval(self, model, reasoning_effort, expected_score):
        """Run a single GPQA eval and assert its score meets the threshold."""
        args = SimpleNamespace(
            base_url=_base_url,
            model=model,
            eval_name="gpqa",
            num_examples=198,
            # Use enough threads to allow parallelism.
            num_threads=198,
            # TODO 4k is still not enough, we need e.g. 64k tokens, but that
            # is super slow; otherwise a lot of questions are not answered.
            max_tokens=4096,
            # simple-evals by default uses 0.5, which is better than 0.0
            # temperature, but here for reproducibility we use 0.1.
            temperature=0.1,
            reasoning_effort=reasoning_effort,
        )
        print(f"Evaluation start: {model=} {reasoning_effort=} {expected_score=}")
        metrics = run_eval(args)
        print(f"Evaluation end: {model=} {reasoning_effort=} {expected_score=} {metrics=}")
        self.assertGreaterEqual(metrics["score"], expected_score)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment