sglang · Commits · 777eb538

Unverified commit 777eb538, authored Sep 27, 2025 by Mick, committed by GitHub Sep 26, 2025
Parent: 05a35266

ci: refactor nightly test (#10495)

Showing 16 changed files with 1656 additions and 187 deletions (+1656 / -187)
.github/workflows/nightly-test.yml                        +78   -4
python/sglang/bench_one_batch.py                           +6    -8
python/sglang/bench_one_batch_server.py                    +306  -32
python/sglang/bench_serving.py                             +5    -1
python/sglang/srt/managers/scheduler_profiler_mixin.py     +3    -3
python/sglang/test/run_eval.py                             +7    -0
python/sglang/test/simple_eval_mmmu_vlm.py                 +441  -0
python/sglang/test/test_utils.py                           +136  -0
scripts/ci/publish_traces.py                               +263  -0
test/srt/run_suite.py                                      +0    -3
test/srt/test_nightly_gsm8k_eval_amd.py                    +2    -29
test/srt/test_nightly_text_models_gsm8k_eval.py            +21   -82
test/srt/test_nightly_text_models_perf.py                  +135  -0
test/srt/test_nightly_vlms_mmmu_eval.py                    +117  -0
test/srt/test_nightly_vlms_perf.py                         +135  -0
test/srt/test_vllm_dependency.py                           +1    -25
.github/workflows/nightly-test.yml   (view file @ 777eb538)

@@ -15,8 +15,8 @@ concurrency:
   cancel-in-progress: true

 jobs:
-  nightly-test:
-    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
+  nightly-test-eval-text-models:
+    if: github.repository == 'sgl-project/sglang'
     runs-on: 2-gpu-runner
     steps:
       - name: Checkout code
@@ -26,8 +26,82 @@ jobs:
         run: |
           bash scripts/ci/ci_install_dependency.sh

-      - name: Run test
+      - name: Run eval test for text models
         timeout-minutes: 120
         run: |
           cd test/srt
-          python3 run_suite.py --suite nightly --timeout-per-file 3600
+          python3 test_nightly_text_models_gsm8k_eval.py
+
+  nightly-test-perf-text-models:
+    if: github.repository == 'sgl-project/sglang'
+    runs-on: 2-gpu-runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install dependencies
+        run: |
+          bash scripts/ci/ci_install_dependency.sh
+
+      - name: Run performance test for text models
+        timeout-minutes: 180
+        env:
+          TRACE_BASE_URL: https://raw.githubusercontent.com/sglang-bot/sglang-ci-data/main/traces/${{ github.run_id }}
+          PERFETTO_RELAY_URL: ${{ vars.PERFETTO_RELAY_URL }}
+        run: |
+          rm -rf test/srt/performance_profiles_text_models/
+          python3 test/srt/test_nightly_text_models_perf.py
+
+      - name: Publish traces to storage repo
+        env:
+          GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI }}
+          GITHUB_RUN_ID: ${{ github.run_id }}
+          GITHUB_RUN_NUMBER: ${{ github.run_number }}
+        run: |
+          python3 scripts/ci/publish_traces.py
+
+  nightly-test-eval-vlms:
+    if: github.repository == 'sgl-project/sglang'
+    runs-on: 1-gpu-runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install dependencies
+        run: |
+          bash scripts/ci/ci_install_dependency.sh
+
+      - name: Run eval test for VLM models (fixed MMMU-100)
+        timeout-minutes: 240
+        run: |
+          cd test/srt
+          python3 test_nightly_vlms_mmmu_eval.py
+
+  nightly-test-perf-vlms:
+    if: github.repository == 'sgl-project/sglang'
+    runs-on: 1-gpu-runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install dependencies
+        run: |
+          bash scripts/ci/ci_install_dependency.sh
+
+      - name: Run perf test for VLM models (MMMU)
+        timeout-minutes: 240
+        env:
+          TRACE_BASE_URL: https://raw.githubusercontent.com/sglang-bot/sglang-ci-data/main/traces/${{ github.run_id }}
+          PERFETTO_RELAY_URL: ${{ vars.PERFETTO_RELAY_URL }}
+        run: |
+          rm -rf test/srt/performance_profiles_vlms/
+          python3 test/srt/test_nightly_vlms_perf.py
+
+      - name: Publish traces to storage repo
+        env:
+          GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI }}
+          GITHUB_RUN_ID: ${{ github.run_id }}
+          GITHUB_RUN_NUMBER: ${{ github.run_number }}
+        run: |
+          python3 scripts/ci/publish_traces.py --vlm
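For orientation, the two environment variables exported by the perf jobs are what later turn into the trace links embedded in the CI report. A minimal sketch of that URL construction (mirroring to_markdown_row in bench_one_batch_server.py below; the run id and trace filename here are placeholders, not values from this commit):

from urllib.parse import quote

run_id = "1234567890"  # placeholder for ${{ github.run_id }}
trace_base_url = f"https://raw.githubusercontent.com/sglang-bot/sglang-ci-data/main/traces/{run_id}"
relay_base = "https://docs.sglang.ai/ci-data/pages/perfetto_relay.html"  # PERFETTO_RELAY_URL

# hypothetical published trace file under the run's directory
raw_file_link = f"{trace_base_url}/batch16_input1024_output8_extend.trace.json.gz"
relay_link = f"{relay_base}?src={quote(raw_file_link, safe='')}"
print(relay_link)  # the link the markdown report embeds as [trace](...)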
python/sglang/bench_one_batch.py   (view file @ 777eb538)

@@ -443,11 +443,9 @@ def latency_test_run_once(
     if profile:
         profiler.stop()
-        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
-        _save_profile_trace_results(profiler, profile_filename)
-        rank_print(
-            f"torch profiler chrome trace for prefill saved to {profile_filename}"
-        )
+        trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+        _save_profile_trace_results(profiler, trace_filename)
+        rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")

     # Decode
     decode_latencies = []
@@ -479,10 +477,10 @@ def latency_test_run_once(
         if profile and i == output_len / 2:
             profiler.stop()
-            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
-            _save_profile_trace_results(profiler, profile_filename)
+            trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+            _save_profile_trace_results(profiler, trace_filename)
             rank_print(
-                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
+                f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
             )

         # Record decode timing from 2nd output
python/sglang/bench_one_batch_server.py   (view file @ 777eb538)

@@ -9,6 +9,7 @@
 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
+python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
 """

 import argparse
@@ -19,12 +20,17 @@ import multiprocessing
 import os
 import random
 import time
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import numpy as np
 import requests
+from pydantic import BaseModel

-from sglang.bench_serving import get_tokenizer, sample_random_requests
+from sglang.bench_serving import (
+    get_tokenizer,
+    sample_mmmu_requests,
+    sample_random_requests,
+)
 from sglang.profiler import run_profile
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
@@ -32,6 +38,109 @@ from sglang.srt.utils import is_blackwell, kill_process_tree
 from sglang.test.test_utils import is_in_ci, write_github_step_summary

+
+class ProfileLinks(BaseModel):
+    """Pydantic model for profile trace links."""
+
+    extend: Optional[str] = None
+    decode: Optional[str] = None
+
+
+class BenchmarkResult(BaseModel):
+    """Pydantic model for benchmark results table data, for a single isl and osl"""
+
+    model_path: str
+    run_name: str
+    batch_size: int
+    input_len: int
+    output_len: int
+    latency: float
+    ttft: float
+    input_throughput: float
+    output_throughput: float
+    overall_throughput: float
+    last_gen_throughput: float
+    acc_length: Optional[float] = None
+    profile_links: Optional[ProfileLinks] = None
+
+    @staticmethod
+    def help_str() -> str:
+        return f"""
+Note: To view the traces through perfetto-ui, please:
+1. use Google Chrome
+2. enable popup
+"""
+
+    def to_markdown_row(self, trace_dir, base_url: str = "", relay_base: str = "") -> str:
+        """Convert this benchmark result to a markdown table row."""
+        # Calculate costs (assuming H100 pricing for now)
+        hourly_cost_per_gpu = 2  # $2/hour for one H100
+        hourly_cost = hourly_cost_per_gpu * 1  # Assuming tp_size = 1 for simplicity
+        input_util = 0.7
+        accept_length = (
+            round(self.acc_length, 2) if self.acc_length is not None else "n/a"
+        )
+        itl = 1 / (self.output_throughput / self.batch_size) * 1000
+        input_cost = 1e6 / (self.input_throughput * input_util) / 3600 * hourly_cost
+        output_cost = 1e6 / self.output_throughput / 3600 * hourly_cost
+
+        def get_perfetto_relay_link_from_trace_file(trace_file: str):
+            import os
+            from urllib.parse import quote
+
+            rel_path = os.path.relpath(trace_file, trace_dir)
+            raw_file_link = f"{base_url}/{rel_path}"
+            relay_link = (
+                f"{relay_base}?src={quote(raw_file_link, safe='')}"
+                if relay_base and quote
+                else raw_file_link
+            )
+            return relay_link
+
+        # Handle profile links
+        profile_link = "NA | NA"
+        if self.profile_links:
+            if self.profile_links.extend or self.profile_links.decode:
+                # Create a combined link or use the first available one
+                trace_files = [self.profile_links.extend, self.profile_links.decode]
+                trace_files_relay_links = [
+                    f"[trace]({get_perfetto_relay_link_from_trace_file(trace_file)})"
+                    for trace_file in trace_files
+                ]
+                profile_link = " | ".join(trace_files_relay_links)
+
+        # Build the row
+        return f"| {self.batch_size} | {self.input_len} | {self.latency:.2f} | {self.input_throughput:.2f} | {self.output_throughput:.2f} | {accept_length} | {itl:.2f} | {input_cost:.2f} | {output_cost:.2f} | {profile_link} |\n"
+
+    @classmethod
+    def generate_markdown_report(cls, trace_dir, results: List["BenchmarkResult"]) -> str:
+        """Generate a markdown report from a list of BenchmarkResult object from a single run."""
+        import os
+
+        summary = f"### {results[0].model_path}\n"
+
+        # summary += (
+        #     f"Input lens: {result.input_len}. Output lens: {result.output_len}.\n"
+        # )
+        summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | profile (extend) | profile (decode)|\n"
+        summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ | --------------- | -------------- |\n"
+
+        # all results should share the same isl & osl
+        for result in results:
+            base_url = os.getenv("TRACE_BASE_URL", "").rstrip("/")
+            relay_base = os.getenv("PERFETTO_RELAY_URL", "").rstrip("/")
+            relay_base = "https://docs.sglang.ai/ci-data/pages/perfetto_relay.html"
+            # base_url = "https://github.com/sgl-project/ci-data/traces"
+            summary += result.to_markdown_row(trace_dir, base_url, relay_base)
+
+        return summary
+
+
 @dataclasses.dataclass
 class BenchArgs:
     run_name: str = "default"
@@ -50,8 +159,12 @@ class BenchArgs:
     profile: bool = False
     profile_steps: int = 3
     profile_by_stage: bool = False
+    profile_filename_prefix: str = None
+    append_to_github_summary: bool = True
     dataset_path: str = ""
     parallel_batch: bool = False
+    dataset_name: str = "random"
+    output_path: Optional[str] = None

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -67,6 +180,13 @@ class BenchArgs:
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument(
+            "--dataset-name",
+            type=str,
+            default=BenchArgs.dataset_name,
+            choices=["mmmu", "random"],
+            help="Name of the dataset to benchmark on.",
+        )
         parser.add_argument("--return-logprob", action="store_true")
         parser.add_argument(
             "--client-stream-interval",
@@ -96,14 +216,36 @@ class BenchArgs:
             help="Path to the dataset.",
         )
         parser.add_argument("--parallel-batch", action="store_true")
+        parser.add_argument(
+            "--profile-filename-prefix",
+            type=str,
+            default=BenchArgs.profile_filename_prefix,
+        )
+        parser.add_argument(
+            "--no-append-to-github-summary",
+            action="store_false",
+            dest="append_to_github_summary",
+            help="Disable appending the output of this run to github ci summary",
+        )
+        parser.add_argument(
+            "--output-path",
+            type=str,
+            default=BenchArgs.output_path,
+            help="Path to save benchmark results as JSON format. If not specified, results will only be saved to result-filename.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         # use the default value's type to cast the args into correct types.
         attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
-        return cls(
-            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
-        )
+        kwargs = {}
+        for attr, attr_type in attrs:
+            val = getattr(args, attr)
+            if attr_type is type(None):
+                kwargs[attr] = val
+            else:
+                kwargs[attr] = attr_type(val)
+        return cls(**kwargs)


 def launch_server_internal(server_args):
@@ -148,13 +290,25 @@ def run_one_case(
     run_name: str,
     result_filename: str,
     tokenizer,
+    dataset_name="",
     profile: bool = False,
     profile_steps: int = 3,
     profile_by_stage: bool = False,
+    profile_filename_prefix: str = None,
     dataset_path: str = "",
     parallel_batch: bool = False,
 ):
     requests.post(url + "/flush_cache")
-    input_requests = sample_random_requests(
-        input_len=input_len,
-        output_len=output_len,
+    # TODO: reuse bench_serving.get_dataset ?
+    if dataset_name == "mmmu":
+        input_requests = sample_mmmu_requests(
+            num_requests=batch_size,
+            tokenizer=tokenizer,
+            fixed_output_len=output_len,
+            apply_chat_template=True,
+            random_sample=False,
+        )
+    elif dataset_name == "random":
+        input_requests = sample_random_requests(
+            input_len=input_len,
+            output_len=output_len,
@@ -181,15 +335,22 @@ def run_one_case(
     profile_link = None
     if profile:
+        output_dir, profile_name = None, None
+        if profile_filename_prefix:
+            output_dir = os.path.dirname(profile_filename_prefix)
+            profile_name = os.path.basename(profile_filename_prefix)
         profile_link: str = run_profile(
-            url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
+            url, profile_steps, ["CPU", "GPU"], output_dir, profile_name, profile_by_stage
         )

     tic = time.perf_counter()
-    response = requests.post(
-        url + "/generate",
-        json={
-            "input_ids": [req.prompt for req in input_requests],
+    payload = {
         "sampling_params": {
             "temperature": temperature,
             "max_new_tokens": output_len,
@@ -200,7 +361,22 @@ def run_one_case(
             "return_logprob": return_logprob,
             "stream": True,
             **({"parallel_batch": parallel_batch} if parallel_batch else {}),
         },
+    }
+    if dataset_name == "mmmu":
+        # vlm
+        input_ids = []
+        for input_req in input_requests:
+            input_ids += [tokenizer.encode(input_req.prompt)]
+        payload["image_data"] = [req.image_data for req in input_requests]
+    else:
+        input_ids = [req.prompt for req in input_requests]
+    payload["input_ids"] = input_ids
+
+    response = requests.post(
+        url + "/generate",
+        json=payload,
         stream=True,
     )
@@ -264,10 +440,100 @@ def run_one_case(
         overall_throughput,
         last_gen_throughput,
         acc_length,
-        profile_link if profile else None,
+        profile_link,
     )


+def save_results_as_json(result: List[Tuple], bench_args: BenchArgs, model: str):
+    """Save benchmark results as JSON using Pydantic models."""
+    json_results = []
+
+    # Generate all parameter combinations to match with results
+    param_combinations = list(
+        itertools.product(
+            bench_args.batch_size, bench_args.input_len, bench_args.output_len
+        )
+    )
+
+    for i, (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+        profile_link,
+    ) in enumerate(result):
+        # Get the corresponding parameters for this result
+        bs, input_len, output_len = param_combinations[i]
+
+        # Parse profile links if available
+        profile_links = None
+        if profile_link:
+            profile_links = parse_profile_links(
+                profile_link, batch_size, input_len, output_len
+            )
+
+        benchmark_result = BenchmarkResult(
+            model_path=model,
+            run_name=bench_args.run_name,
+            batch_size=batch_size,
+            input_len=input_len,
+            output_len=output_len,
+            latency=latency,
+            ttft=ttft,
+            input_throughput=input_throughput,
+            output_throughput=output_throughput,
+            overall_throughput=overall_throughput,
+            last_gen_throughput=last_gen_throughput,
+            acc_length=acc_length,
+            profile_links=profile_links,
+        )
+
+        json_results.append(benchmark_result.model_dump())
+
+    # Save to JSON file
+    with open(bench_args.output_path, "w", encoding="utf-8") as f:
+        json.dump(json_results, f, indent=2, ensure_ascii=False)
+
+    print(f"Results saved as JSON to {bench_args.output_path}")
+
+
+def parse_profile_links(
+    profile_dir: str, batch_size: int, input_len: int, output_len: int
+) -> Optional[ProfileLinks]:
+    """Parse profile directory to extract extend and decode trace file links."""
+    if not profile_dir or not os.path.exists(profile_dir):
+        return None
+
+    extend_link = None
+    decode_link = None
+
+    # Look for extend/prefill trace files
+    for file in os.listdir(profile_dir):
+        if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
+            if "extend" in file.lower() or "prefill" in file.lower():
+                extend_link = os.path.join(profile_dir, file)
+            elif "decode" in file.lower():
+                decode_link = os.path.join(profile_dir, file)
+
+    # If no specific extend/decode files found, try to find files with batch/input/output info
+    if not extend_link or not decode_link:
+        for file in os.listdir(profile_dir):
+            if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
+                if f"_batch{batch_size}_input{input_len}_output{output_len}_" in file:
+                    if "prefill" in file.lower() or "extend" in file.lower():
+                        extend_link = os.path.join(profile_dir, file)
+                    elif "decode" in file.lower():
+                        decode_link = os.path.join(profile_dir, file)
+
+    if extend_link or decode_link:
+        return ProfileLinks(extend=extend_link, decode=decode_link)
+
+    return None
+
+
 def get_report_summary(
     result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
 ):
@@ -358,6 +624,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             return_logprob=bench_args.return_logprob,
             stream_interval=bench_args.client_stream_interval,
             input_len_step_percentage=bench_args.input_len_step_percentage,
+            dataset_name=bench_args.dataset_name,
             run_name="",
             result_filename="",
             tokenizer=tokenizer,
@@ -384,10 +651,12 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                     stream_interval=bench_args.client_stream_interval,
                     input_len_step_percentage=bench_args.input_len_step_percentage,
                     run_name=bench_args.run_name,
+                    dataset_name=bench_args.dataset_name,
                     result_filename=bench_args.result_filename,
                     tokenizer=tokenizer,
                     dataset_path=bench_args.dataset_path,
                     parallel_batch=bench_args.parallel_batch,
+                    profile_filename_prefix=bench_args.profile_filename_prefix,
                 )
             )
@@ -410,11 +679,13 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 run_name=bench_args.run_name,
                 result_filename=bench_args.result_filename,
                 tokenizer=tokenizer,
+                dataset_name=bench_args.dataset_name,
                 profile=bench_args.profile,
                 profile_steps=bench_args.profile_steps,
                 profile_by_stage=bench_args.profile_by_stage,
                 dataset_path=bench_args.dataset_path,
                 parallel_batch=bench_args.parallel_batch,
+                profile_filename_prefix=bench_args.profile_filename_prefix,
             )[-1],
         )
@@ -427,13 +698,16 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     print(f"\nResults are saved to {bench_args.result_filename}")

+    # Save results as JSON if output_path is specified
+    if bench_args.output_path:
+        save_results_as_json(result, bench_args, model=server_args.model_path)
+
     if not bench_args.show_report:
         return

     summary = get_report_summary(result, server_args, bench_args)
+    print(summary)

-    if is_in_ci():
+    if is_in_ci() and bench_args.append_to_github_summary:
         write_github_step_summary(summary)
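The new --output-path flag dumps one JSON object per (batch size, input len, output len) combination via BenchmarkResult.model_dump(). A hedged post-processing sketch (not part of this commit) showing how such a file could be read back and turned into the same markdown report; the results.json path and the trace directory name are assumptions:

import json
from sglang.bench_one_batch_server import BenchmarkResult

with open("results.json") as f:
    rows = json.load(f)

# Rebuild the pydantic objects and regenerate the markdown report offline.
results = [BenchmarkResult(**row) for row in rows]
report = BenchmarkResult.generate_markdown_report(
    "performance_profiles_text_models",  # assumed trace directory
    results,
)
print(report)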
python/sglang/bench_serving.py   (view file @ 777eb538)

@@ -208,6 +208,10 @@ async def async_request_openai_completions(
             "ignore_eos": not args.disable_ignore_eos,
             **request_func_input.extra_request_body,
         }
+
+        if request_func_input.image_data:
+            payload.update({"image_data": request_func_input.image_data})
+
         headers = get_auth_headers()

         output = RequestFuncOutput.init_new(request_func_input)
@@ -664,7 +668,7 @@ def get_dataset(args, tokenizer):
             num_prompts=args.num_prompts,
             range_ratio=args.random_range_ratio,
             tokenizer=tokenizer,
-            dataset_path=args.dataset_path,
+            dataset_path=args.dataset_name,
             random_sample=args.dataset_name == "random",
             return_text=not tokenize_prompt,
         )
python/sglang/srt/managers/scheduler_profiler_mixin.py   (view file @ 777eb538)

@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
     def start_profile(
         self, stage: Optional[ForwardMode] = None
     ) -> ProfileReqOutput | None:
-        stage_str = f" for {stage.__str__()}" if stage else ""
+        stage_str = f" for {stage.name}" if stage else ""
         logger.info(
             f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
         )
@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
         if not Path(self.torch_profiler_output_dir).exists():
             Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

-        stage_suffix = f"-{stage.__str__()}" if stage else ""
+        stage_suffix = f"-{stage.name}" if stage else ""
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
         if self.profiler_decode_ct == 0:
             if self.profile_in_progress:
                 # force trace flush
-                self.stop_profile(ForwardMode.EXTEND)
+                self.stop_profile(stage=ForwardMode.EXTEND)
             self.start_profile(batch.forward_mode)
         self.profiler_decode_ct += 1
         if self.profiler_decode_ct > self.profiler_target_decode_ct:
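The switch from stage.__str__() to stage.name keeps log lines and trace-file suffixes short and stable. A small illustration with a generic enum (not sglang's ForwardMode, whose exact base class is not shown in this diff):

from enum import Enum, auto

class Mode(Enum):
    EXTEND = auto()
    DECODE = auto()

print(str(Mode.EXTEND))   # "Mode.EXTEND"  -- what __str__() produced before
print(Mode.EXTEND.name)   # "EXTEND"       -- what the log message / suffix now uses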
python/sglang/test/run_eval.py   (view file @ 777eb538)

@@ -60,6 +60,11 @@ def run_eval(args):
         from sglang.test.simple_eval_humaneval import HumanEval

         eval_obj = HumanEval(args.num_examples, args.num_threads)
+    elif args.eval_name == "mmmu":
+        # VLM MMMU evaluation with fixed 100 examples by default
+        from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval
+
+        eval_obj = MMMUVLMEval(args.num_examples, args.num_threads)
     else:
         raise ValueError(f"Invalid eval name: {args.eval_name}")
@@ -94,6 +99,8 @@ def run_eval(args):
     print(f"Total latency: {latency:.3f} s")
     print(f"Score: {metrics['score']:.3f}")

+    if getattr(args, "return_latency", False):
+        return metrics, latency
     return metrics
python/sglang/test/simple_eval_mmmu_vlm.py   (new file, mode 100644)   (view file @ 777eb538)

"""
MMMU evaluation for VLMs using the run_eval simple-evals interface.
"""

from __future__ import annotations

import base64
import io
from typing import List, Optional, Tuple

from datasets import concatenate_datasets, load_dataset
from PIL import Image

from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
    HTML_JINJA,
    Eval,
    EvalResult,
    SamplerBase,
    SingleEvalResult,
    map_with_progress,
)


class MMMUVLMEval(Eval):
    DOMAIN_CAT2SUB_CAT = {
        "Art and Design": ["Art", "Art_Theory", "Design", "Music"],
        "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
        "Science": ["Biology", "Chemistry", "Geography", "Math", "Physics"],
        "Health and Medicine": [
            "Basic_Medical_Science",
            "Clinical_Medicine",
            "Diagnostics_and_Laboratory_Medicine",
            "Pharmacy",
            "Public_Health",
        ],
        "Humanities and Social Science": [
            "History",
            "Literature",
            "Sociology",
            "Psychology",
        ],
        "Tech and Engineering": [
            "Agriculture",
            "Architecture_and_Engineering",
            "Computer_Science",
            "Electronics",
            "Energy_and_Power",
            "Materials",
            "Mechanical_Engineering",
        ],
    }

    def __init__(
        self, num_examples: Optional[int] = 100, num_threads: int = 32, seed: int = 42
    ):
        """Create MMMU VLM eval (Math subset, 100 fixed samples by default)."""
        self.num_examples = num_examples
        self.num_threads = num_threads
        self.seed = seed
        # Prepare samples deterministically across all MMMU subjects (validation split)
        self.samples = self._prepare_mmmu_samples(self.num_examples)

    @staticmethod
    def _to_data_uri(image: Image.Image) -> str:
        if image.mode == "RGBA":
            image = image.convert("RGB")
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{b64}"

    @staticmethod
    def _build_mc_mapping(options: List[str]) -> Tuple[dict, List[str]]:
        index2ans = {}
        all_choices = []
        ch = ord("A")
        for opt in options:
            letter = chr(ch)
            index2ans[letter] = opt
            all_choices.append(letter)
            ch += 1
        return index2ans, all_choices

    def _prepare_mmmu_samples(self, k: int) -> List[dict]:
        # Subjects and domains copied from MMMU data_utils to categorize results
        subjects: List[str] = []
        for subs in self.DOMAIN_CAT2SUB_CAT.values():
            subjects.extend(subs)

        # Load validation split of each subject
        datasets = []
        for subj in subjects:
            try:
                d = load_dataset("MMMU/MMMU", subj, split="validation")
                # attach subject info via transform
                d = d.add_column("__subject__", [subj] * len(d))
                datasets.append(d)
            except Exception:
                continue
        if not datasets:
            raise RuntimeError("Failed to load MMMU datasets")

        merged = concatenate_datasets(datasets)

        # Deterministic selection: sort by id (fallback to subject+index)
        def _key(idx):
            ex = merged[idx]
            return str(ex.get("id", f"{ex['__subject__']}:{idx}"))

        order = sorted(range(len(merged)), key=_key)
        picked_indices = order[:k]

        samples: List[dict] = []
        for idx in picked_indices:
            ex = merged[idx]
            subject = ex["__subject__"]
            image = ex.get("image_1")
            if image is None or not hasattr(image, "convert"):
                continue
            data_uri = self._to_data_uri(image)
            question = ex.get("question", "")
            answer = ex.get("answer")
            raw_options = ex.get("options")
            question_type = "open"
            index2ans = None
            all_choices = None
            options = None
            if raw_options:
                try:
                    options = (
                        raw_options
                        if isinstance(raw_options, list)
                        else list(eval(raw_options))
                    )
                    if isinstance(options, list) and len(options) > 0:
                        index2ans, all_choices = self._build_mc_mapping(options)
                        question_type = "multiple-choice"
                except Exception:
                    options = None

            # Build final textual prompt; include choices if MC
            prompt_text = f"Question: {question}\n\n"
            if options:
                letters = [chr(ord("A") + i) for i in range(len(options))]
                for letter, opt in zip(letters, options):
                    prompt_text += f"{letter}) {opt}\n"
            prompt_text += "\nAnswer: "

            samples.append(
                {
                    "id": ex.get("id", f"{subject}:{idx}"),
                    "final_input_prompt": prompt_text,
                    "image_data": data_uri,
                    "answer": answer,
                    "question_type": question_type,
                    "index2ans": index2ans,
                    "all_choices": all_choices,
                    "category": subject,
                }
            )

        return samples

    @staticmethod
    def _split_prompt_for_image(prompt: str) -> tuple[str, str]:
        """Split a prompt containing an inline image tag into prefix and suffix.

        If no tag is present, treat the whole prompt as prefix and empty suffix.
        """
        if "<" in prompt and ">" in prompt:
            prefix = prompt.split("<")[0]
            suffix = prompt.split(">", 1)[1]
            return prefix, suffix
        return prompt, ""

    @staticmethod
    def build_chat_messages_from_prompt(prompt: str, image_data) -> List:
        """Split a prompt containing an inline image tag into prefix and suffix.

        If no tag is present, treat the whole prompt as prefix and empty suffix.
        """
        # Build a vision+text message for OpenAI-compatible API
        prefix, suffix = MMMUVLMEval._split_prompt_for_image(prompt)

        content: List[dict] = []
        if prefix:
            content.append({"type": "text", "text": prefix})
        content.append({"type": "image_url", "image_url": {"url": image_data}})
        if suffix:
            content.append({"type": "text", "text": suffix})

        prompt_messages = [{"role": "user", "content": content}]

        return prompt_messages

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def fn(sample: dict):
            prompt = sample["final_input_prompt"]
            image_data = sample["image_data"]
            prompt_messages = MMMUVLMEval.build_chat_messages_from_prompt(
                prompt, image_data
            )

            # Sample
            response_text = sampler(prompt_messages)

            # Parse and score
            gold = sample["answer"]
            if (
                sample["question_type"] == "multiple-choice"
                and sample["all_choices"]
                and sample["index2ans"]
            ):
                pred = _parse_multi_choice_response(
                    response_text, sample["all_choices"], sample["index2ans"]
                )
                score = 1.0 if (gold is not None and pred == gold) else 0.0
                extracted_answer = pred
            else:
                parsed_list = _parse_open_response(response_text)
                score = (
                    1.0 if (gold is not None and _eval_open(gold, parsed_list)) else 0.0
                )
                extracted_answer = ", ".join(map(str, parsed_list))

            html_rendered = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=gold,
                extracted_answer=extracted_answer,
            )

            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(
                html=html_rendered,
                score=score,
                metrics={"__category__": sample["category"]},
                convo=convo,
            )

        results = map_with_progress(fn, self.samples, self.num_threads)

        # Build category table and overall accuracy
        # Gather per-sample correctness and category
        per_cat_total: dict[str, int] = {}
        per_cat_correct: dict[str, int] = {}
        htmls = []
        convos = []
        scores: List[float] = []
        for r in results:
            # __category__ stored under metrics
            cat = r.metrics.get("__category__") if r.metrics else None
            if cat is None:
                cat = "Unknown"
            per_cat_total[cat] = per_cat_total.get(cat, 0) + 1
            if r.score:
                per_cat_correct[cat] = per_cat_correct.get(cat, 0) + 1
            htmls.append(r.html)
            convos.append(r.convo)
            if r.score is not None:
                scores.append(r.score)

        evaluation_result = {}
        for cat, tot in per_cat_total.items():
            corr = per_cat_correct.get(cat, 0)
            acc = (corr / tot) if tot > 0 else 0.0
            evaluation_result[cat] = {"acc": round(acc, 3), "num_example": tot}

        printable_results = {}
        # Domains first
        for domain, cats in self.DOMAIN_CAT2SUB_CAT.items():
            acc_sum = 0.0
            num_sum = 0
            for cat in cats:
                if cat in evaluation_result:
                    acc_sum += (
                        evaluation_result[cat]["acc"]
                        * evaluation_result[cat]["num_example"]
                    )
                    num_sum += evaluation_result[cat]["num_example"]
            if num_sum > 0:
                printable_results[f"Overall-{domain}"] = {
                    "num": num_sum,
                    "acc": round(acc_sum / num_sum, 3),
                }
            # add each sub-category row if present
            for cat in cats:
                if cat in evaluation_result:
                    printable_results[cat] = {
                        "num": evaluation_result[cat]["num_example"],
                        "acc": evaluation_result[cat]["acc"],
                    }

        # Overall
        total_num = sum(v["num_example"] for v in evaluation_result.values())
        overall_acc = (
            sum(v["acc"] * v["num_example"] for v in evaluation_result.values())
            / total_num
            if total_num > 0
            else 0.0
        )
        printable_results["Overall"] = {"num": total_num, "acc": round(overall_acc, 3)}

        # Build EvalResult
        return EvalResult(
            score=overall_acc, metrics=printable_results, htmls=htmls, convos=convos
        )


def _parse_multi_choice_response(
    response: str, all_choices: List[str], index2ans: dict
) -> str:
    # loosely adapted from benchmark mmmu eval
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    response = " " + response + " "

    # Prefer explicit letter with bracket e.g. (A)
    candidates: List[str] = []
    for choice in all_choices:
        if f"({choice})" in response:
            candidates.append(choice)
    if not candidates:
        for choice in all_choices:
            if f" {choice} " in response:
                candidates.append(choice)
    if not candidates and len(response.split()) > 5:
        # try match by option text
        for idx, ans in index2ans.items():
            if ans and ans.lower() in response.lower():
                candidates.append(idx)
    if not candidates:
        # fallback to first choice
        return all_choices[0]
    if len(candidates) == 1:
        return candidates[0]

    # choose the last occurrence
    starts = []
    for can in candidates:
        pos = response.rfind(f"({can})")
        if pos == -1:
            pos = response.rfind(f" {can} ")
        if pos == -1 and index2ans.get(can):
            pos = response.lower().rfind(index2ans[can].lower())
        starts.append(pos)
    return candidates[int(max(range(len(starts)), key=lambda i: starts[i]))]


def _check_is_number(s: str) -> bool:
    try:
        float(s.replace(",", ""))
        return True
    except Exception:
        return False


def _normalize_str(s: str):
    s = s.strip()
    if _check_is_number(s):
        s = s.replace(",", "")
        try:
            v = round(float(s), 2)
            return [v]
        except Exception:
            return [s.lower()]
    return [s.lower()] if len(s) > 1 else [" " + s, s + " "]


def _extract_numbers(s: str) -> List[str]:
    import re as _re

    pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
    pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
    pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
    return (
        _re.findall(pattern_commas, s)
        + _re.findall(pattern_scientific, s)
        + _re.findall(pattern_simple, s)
    )


def _parse_open_response(response: str) -> List[str]:
    import re as _re

    def get_key_subresponses(resp: str) -> List[str]:
        resp = resp.strip().strip(".").lower()
        subs = _re.split(r"\.\s(?=[A-Z])|\n", resp)
        indicators = [
            "could be ",
            "so ",
            "is ",
            "thus ",
            "therefore ",
            "final ",
            "answer ",
            "result ",
        ]
        keys = []
        for i, s in enumerate(subs):
            cands = [*indicators]
            if i == len(subs) - 1:
                cands.append("=")
            shortest = None
            for ind in cands:
                if ind in s:
                    part = s.split(ind)[-1].strip()
                    if not shortest or len(part) < len(shortest):
                        shortest = part
            if shortest and shortest not in [":", ",", ".", "!", "?", ";", ":", "'"]:
                keys.append(shortest)
        return keys or [resp]

    key_resps = get_key_subresponses(response)
    pred_list = key_resps.copy()
    for r in key_resps:
        pred_list.extend(_extract_numbers(r))

    out = []
    for x in pred_list:
        out.extend(_normalize_str(x))
    # dedup
    return list(dict.fromkeys(out))


def _eval_open(gold, preds: List[str]) -> bool:
    if isinstance(gold, list):
        norm_answers = []
        for ans in gold:
            norm_answers.extend(_normalize_str(ans))
    else:
        norm_answers = _normalize_str(gold)
    for p in preds:
        if isinstance(p, str):
            for na in norm_answers:
                if isinstance(na, str) and na in p:
                    return True
        else:
            if p in norm_answers:
                return True
    return False
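A small illustrative check of the answer-parsing helpers above (not part of the commit; it assumes the module is importable at the path introduced by this file):

from sglang.test.simple_eval_mmmu_vlm import (
    _parse_multi_choice_response,
    _parse_open_response,
)

choices = ["A", "B", "C", "D"]
index2ans = {"A": "3", "B": "4", "C": "5", "D": "6"}
# Explicit "(B)" is preferred over other heuristics.
print(_parse_multi_choice_response("The answer is (B).", choices, index2ans))  # "B"
# Open-ended answers are split on indicator phrases and numbers are normalized.
print(_parse_open_response("Therefore the result is 1,234."))  # [1234.0, 234.0]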
python/sglang/test/test_utils.py   (view file @ 777eb538)

@@ -14,10 +14,12 @@ import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
+from datetime import datetime
 from functools import partial
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from urllib.parse import quote

 import aiohttp
 import numpy as np
@@ -1467,3 +1469,137 @@ def dump_bench_raw_result(
 def _ensure_remove_suffix(text: str, suffix: str):
     assert text.endswith(suffix)
     return text.removesuffix(suffix)
+
+
+class ModelDeploySetup:
+    def __init__(self, model_path: str, extra_args: List[str] = []):
+        self.model_path = model_path
+        if "--enable-multimodal" not in extra_args:
+            extra_args.append("--enable-multimodal")
+        if "--trust-remote-code" not in extra_args:
+            extra_args.append("--trust-remote-code")
+        self.extra_args = extra_args
+
+
+class ModelEvalMetrics:
+    def __init__(self, accuracy: float, eval_time: float):
+        self.accuracy = accuracy
+        self.eval_time = eval_time
+
+
+def extract_trace_link_from_bench_one_batch_server_output(output: str) -> str:
+    match = re.search(r"\[Profile\]\((.*?)\)", output)
+    if match:
+        trace_link = match.group(1)
+        return trace_link
+    return None
+
+
+def parse_models(model_string: str):
+    return [model.strip() for model in model_string.split(",") if model.strip()]
+
+
+def check_evaluation_test_results(
+    results,
+    test_name,
+    model_accuracy_thresholds,
+    model_latency_thresholds=None,
+    model_count=None,
+):
+    """
+    results: list of tuple of (model_path, accuracy, latency)
+    """
+    failed_models = []
+
+    if model_latency_thresholds is not None:
+        summary = " | model | status | score | score_threshold | latency | latency_threshold |\n"
+        summary += "| ----- | ------ | ----- | --------------- | ------- | ----------------- |\n"
+    else:
+        summary = " | model | status | score | score_threshold |\n"
+        summary += "| ----- | ------ | ----- | --------------- |\n"
+
+    for model, accuracy, latency in results:
+        accuracy_threshold = model_accuracy_thresholds.get(model)
+        if accuracy_threshold is None:
+            print(f"Warning: No threshold defined for model {model}")
+            continue
+
+        latency_threshold = (
+            model_latency_thresholds.get(model, None)
+            if model_latency_thresholds
+            else 1e9
+        )
+
+        is_success = accuracy >= accuracy_threshold and latency <= latency_threshold
+        status_emoji = "✅" if is_success else "❌"
+
+        if not is_success:
+            failed_models.append(
+                f"\nScore Check Failed: {model}\n"
+                f"Model {model} score ({accuracy:.4f}) is below threshold ({accuracy_threshold:.4f})"
+            )
+
+        if model_latency_thresholds is not None:
+            line = f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold} | {latency} | {latency_threshold}\n"
+        else:
+            line = f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold}\n"
+
+        summary += line
+
+    print(summary)
+
+    if is_in_ci():
+        write_github_step_summary(f"## {test_name}\n{summary}")
+
+    some_model_failed_to_get_result = len(results) != (
+        model_count or len(model_accuracy_thresholds)
+    )
+    if some_model_failed_to_get_result:
+        print("Some model has failed to launch and be evaluated")
+
+    if failed_models or some_model_failed_to_get_result:
+        raise AssertionError("\n".join(failed_models))
+
+
+# Bench knobs for bench_one_batch_server (override by env)
+def _parse_int_list_env(name: str, default_val: str):
+    val = os.environ.get(name, default_val)
+    return [int(x) for x in val.split(",") if x]
+
+
+# Return filenames
+def find_traces_under_path(path: str) -> List[str]:
+    results = []
+    for _, dirs, files in os.walk(path):
+        for file in files:
+            if file.endswith(".trace.json.gz"):
+                results.append(f"{file}")
+    return results
+
+
+def write_results_to_json(model, metrics, mode="a"):
+    result = {
+        "timestamp": datetime.now().isoformat(),
+        "model": model,
+        "metrics": metrics,
+        "score": metrics["score"],
+    }
+    if "latency" in metrics:
+        result["latency"] = (metrics.get("latency"),)
+
+    existing_results = []
+    if mode == "a" and os.path.exists("results.json"):
+        try:
+            with open("results.json", "r") as f:
+                existing_results = json.load(f)
+        except json.JSONDecodeError:
+            existing_results = []
+
+    if isinstance(existing_results, list):
+        existing_results.append(result)
+    else:
+        existing_results = [result]
+
+    with open("results.json", "w") as f:
+        json.dump(existing_results, f, indent=2)
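A hedged usage sketch for the new check_evaluation_test_results helper; the model name, scores, and thresholds below are illustrative only:

from sglang.test.test_utils import check_evaluation_test_results

results = [
    # (model_path, accuracy, latency)
    ("meta-llama/Llama-3.1-8B-Instruct", 0.84, 95.0),
]
check_evaluation_test_results(
    results,
    "nightly-text-models-gsm8k",
    model_accuracy_thresholds={"meta-llama/Llama-3.1-8B-Instruct": 0.80},
    model_latency_thresholds={"meta-llama/Llama-3.1-8B-Instruct": 120.0},
    model_count=1,
)
# Prints a markdown summary table (and writes it to the GitHub step summary in CI);
# raises AssertionError if any model misses its threshold or produced no result.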
scripts/ci/publish_traces.py   (new file, mode 100644)   (view file @ 777eb538)

"""
Publish performance traces to GitHub repository
"""

import argparse
import base64
import json
import os
import sys
from urllib.request import Request, urlopen


def make_github_request(url, token, method="GET", data=None):
    """Make authenticated request to GitHub API"""
    headers = {
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {token}",
        # "User-Agent": "sglang-ci",
        "X-GitHub-Api-Version": "2022-11-28",
    }

    if data:
        headers["Content-Type"] = "application/json"
        data = json.dumps(data).encode("utf-8")

    req = Request(url, data=data, headers=headers, method=method)

    try:
        with urlopen(req) as response:
            return response.read().decode("utf-8")
    except Exception as e:
        print(f"GitHub API request failed: {e}")
        if hasattr(e, "read"):
            try:
                error_body = e.read().decode("utf-8")
                print(f"Error response body: {error_body}")
            except:
                pass
        raise


def verify_token_permissions(repo_owner, repo_name, token):
    """Verify that the token has necessary permissions for the repository"""
    print("Verifying token permissions...")

    # Check if we can access the repository
    try:
        url = f"https://api.github.com/repos/{repo_owner}/{repo_name}"
        response = make_github_request(url, token)
        repo_data = json.loads(response)
        print(f"Repository access verified: {repo_data['full_name']}")
    except Exception as e:
        print(f"Failed to access repository: {e}")
        return False

    # Check if we can read the repository contents
    try:
        url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/contents"
        response = make_github_request(url, token)
        print("Repository contents access verified")
    except Exception as e:
        print(f"Failed to access repository contents: {e}")
        return False

    return True


def get_branch_sha(repo_owner, repo_name, branch, token):
    """Get SHA of the branch head"""
    url = (
        f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/refs/heads/{branch}"
    )
    response = make_github_request(url, token)
    data = json.loads(response)
    return data["object"]["sha"]


def get_tree_sha(repo_owner, repo_name, commit_sha, token):
    """Get tree SHA from commit"""
    url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/commits/{commit_sha}"
    response = make_github_request(url, token)
    data = json.loads(response)
    return data["tree"]["sha"]


def create_blob(repo_owner, repo_name, content, token):
    """Create a blob with file content"""
    url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/blobs"

    # Encode content as base64 for GitHub API
    content_b64 = base64.b64encode(content).decode("utf-8")

    data = {"content": content_b64, "encoding": "base64"}
    response = make_github_request(url, token, method="POST", data=data)
    return json.loads(response)["sha"]


def create_tree(repo_owner, repo_name, base_tree_sha, files, token):
    """Create a new tree with files"""
    url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/trees"

    tree_items = []
    for file_path, content in files:
        # Create blob first to get SHA
        blob_sha = create_blob(repo_owner, repo_name, content, token)
        tree_items.append(
            {
                "path": file_path,
                "mode": "100644",
                "type": "blob",
                "sha": blob_sha,
            }
        )

    data = {"base_tree": base_tree_sha, "tree": tree_items}
    response = make_github_request(url, token, method="POST", data=data)
    return json.loads(response)["sha"]


def create_commit(repo_owner, repo_name, tree_sha, parent_sha, message, token):
    """Create a new commit"""
    url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/commits"
    data = {"tree": tree_sha, "parents": [parent_sha], "message": message}
    response = make_github_request(url, token, method="POST", data=data)
    return json.loads(response)["sha"]


def update_branch_ref(repo_owner, repo_name, branch, commit_sha, token):
    """Update branch reference to point to new commit"""
    url = (
        f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/refs/heads/{branch}"
    )
    data = {"sha": commit_sha}
    make_github_request(url, token, method="PATCH", data=data)


def copy_trace_files(source_dir, target_base_path, is_vlm=False):
    """Copy trace files and return list of files to upload"""
    files_to_upload = []

    if not os.path.exists(source_dir):
        print(f"Warning: Traces directory {source_dir} does not exist")
        return files_to_upload

    # Walk through source directory and find .json.gz files
    for root, dirs, files in os.walk(source_dir):
        for file in files:
            if file.endswith(".json.gz"):
                source_file = os.path.join(root, file)
                # Calculate relative path from source_dir
                rel_path = os.path.relpath(source_file, source_dir)
                target_path = f"{target_base_path}/{rel_path}"

                # Read file content
                with open(source_file, "rb") as f:
                    content = f.read()

                files_to_upload.append((target_path, content))

    return files_to_upload


def publish_traces(traces_dir, run_id, run_number, is_vlm=False):
    """Publish traces to GitHub repository in a single commit"""
    # Get environment variables
    token = os.getenv("GITHUB_TOKEN")
    if not token:
        print("Error: GITHUB_TOKEN environment variable not set")
        sys.exit(1)

    # Repository configuration
    repo_owner = "sglang-bot"
    repo_name = "sglang-ci-data"
    branch = "main"
    target_base_path = f"traces/{run_id}"

    # Copy trace files
    files_to_upload = copy_trace_files(traces_dir, target_base_path, is_vlm)

    if not files_to_upload:
        print("No trace files found to upload")
        return

    print(f"Found {len(files_to_upload)} files to upload")

    # Verify token permissions before proceeding
    if not verify_token_permissions(repo_owner, repo_name, token):
        print("Token permission verification failed. Please check the token permissions.")
        sys.exit(1)

    try:
        # Get current branch head
        branch_sha = get_branch_sha(repo_owner, repo_name, branch, token)
        print(f"Current branch head: {branch_sha}")

        # Get current tree
        tree_sha = get_tree_sha(repo_owner, repo_name, branch_sha, token)
        print(f"Current tree SHA: {tree_sha}")

        # Create new tree with all files
        new_tree_sha = create_tree(
            repo_owner, repo_name, tree_sha, files_to_upload, token
        )
        print(f"Created new tree: {new_tree_sha}")

        # Create commit
        commit_message = f"Nightly traces for run {run_id} at {run_number} ({len(files_to_upload)} files)"
        commit_sha = create_commit(
            repo_owner, repo_name, new_tree_sha, branch_sha, commit_message, token
        )
        print(f"Created commit: {commit_sha}")

        # Update branch reference
        update_branch_ref(repo_owner, repo_name, branch, commit_sha, token)
        print("Updated branch reference")

        print("Successfully published all traces in a single commit")

    except Exception as e:
        print(f"Failed to publish traces: {e}")
        raise


def main():
    parser = argparse.ArgumentParser(
        description="Publish performance traces to GitHub repository"
    )
    parser.add_argument("--vlm", action="store_true", help="Process VLM model traces")
    args = parser.parse_args()

    # Get environment variables
    run_id = os.getenv("GITHUB_RUN_ID", "test")
    run_number = os.getenv("GITHUB_RUN_NUMBER", "12345")

    if not run_id or not run_number:
        print(
            "Error: GITHUB_RUN_ID and GITHUB_RUN_NUMBER environment variables must be set"
        )
        sys.exit(1)

    # Determine traces directory
    if args.vlm:
        traces_dir = "performance_profiles_vlms"
        print("Processing VLM model traces")
    else:
        traces_dir = "performance_profiles_text_models"
        print("Processing text model traces")

    # Publish traces
    publish_traces(traces_dir, run_id, run_number, args.vlm)


if __name__ == "__main__":
    main()
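A hedged local-use sketch (not part of the commit): calling publish_traces directly instead of via the workflow step. The token value and run identifiers are placeholders, and the import path assumes the script is importable from the repository root; the target repository (sglang-bot/sglang-ci-data) is hard-coded in publish_traces above.

import os

os.environ["GITHUB_TOKEN"] = "<personal access token with write access>"  # placeholder

from scripts.ci.publish_traces import publish_traces  # assumed importable from repo root

# Publishes every *.json.gz under the directory to traces/<run_id>/ in one commit.
publish_traces("performance_profiles_text_models", run_id="local-test", run_number="0")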
test/srt/run_suite.py   (view file @ 777eb538)

@@ -165,9 +165,6 @@ suites = {
     "per-commit-8-gpu-h20": [
         TestFile("quant/test_w4a8_deepseek_v3.py", 371),
     ],
-    "nightly": [
-        TestFile("test_nightly_gsm8k_eval.py"),
-    ],
     "vllm_dependency_test": [
         TestFile("quant/test_awq.py", 163),
         TestFile("test_bnb.py", 5),
test/srt/test_nightly_gsm8k_eval_amd.py

@@ -15,8 +15,10 @@ from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     is_in_ci,
+    parse_models,
     popen_launch_server,
     write_github_step_summary,
+    write_results_to_json,
 )

 MODEL_SCORE_THRESHOLDS = {
@@ -73,10 +75,6 @@ TRITON_MOE_MODELS = {
 }


-def parse_models(model_string):
-    return [model.strip() for model in model_string.split(",") if model.strip()]
-
-
 def popen_launch_server_wrapper(base_url, model, is_tp2):
     other_args = ["--log-level-http", "warning", "--trust-remote-code"]
     if is_tp2:
@@ -91,31 +89,6 @@ def popen_launch_server_wrapper(base_url, model, is_tp2):
     return process


-def write_results_to_json(model, metrics, mode="a"):
-    result = {
-        "timestamp": datetime.now().isoformat(),
-        "model": model,
-        "metrics": metrics,
-        "score": metrics["score"],
-    }
-
-    existing_results = []
-    if mode == "a" and os.path.exists("results.json"):
-        try:
-            with open("results.json", "r") as f:
-                existing_results = json.load(f)
-        except json.JSONDecodeError:
-            existing_results = []
-
-    if isinstance(existing_results, list):
-        existing_results.append(result)
-    else:
-        existing_results = [result]
-
-    with open("results.json", "w") as f:
-        json.dump(existing_results, f, indent=2)
-
-
 def check_model_scores(results):
     failed_models = []
     summary = " | model | score | threshold |\n"
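With this change, parse_models and write_results_to_json are no longer duplicated per test file; they are imported from sglang.test.test_utils, which this commit extends. For reference, a sketch of the shared parse_models, with its body taken from the removed local definition above (the type hints are added here, and the exact form in test_utils may differ):

    def parse_models(model_string: str) -> list[str]:
        # Split a comma-separated model list and drop empty entries.
        return [model.strip() for model in model_string.split(",") if model.strip()]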
test/srt/test_nightly_gsm8k_eval.py → test/srt/test_nightly_text_models_gsm8k_eval.py

 import json
-import os
 import unittest
 import warnings
-from datetime import datetime
 from types import SimpleNamespace

 from sglang.srt.utils import kill_process_tree
@@ -14,9 +12,10 @@ from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
-    is_in_ci,
+    check_evaluation_test_results,
+    parse_models,
     popen_launch_server,
-    write_github_step_summary,
+    write_results_to_json,
 )

 MODEL_SCORE_THRESHOLDS = {
@@ -25,11 +24,11 @@ MODEL_SCORE_THRESHOLDS = {
     "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.85,
     "google/gemma-2-27b-it": 0.91,
     "meta-llama/Llama-3.1-70B-Instruct": 0.95,
-    "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.64,
+    "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.62,
     "Qwen/Qwen2-57B-A14B-Instruct": 0.86,
     "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.83,
     "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.54,
-    "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.84,
+    "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.835,
     "zai-org/GLM-4.5-Air-FP8": 0.75,
     # The threshold of neuralmagic/gemma-2-2b-it-FP8 should be 0.6, but this model has some accuracy regression.
     # The fix is tracked at https://github.com/sgl-project/sglang/issues/4324, we set it to 0.50, for now, to make CI green.
@@ -41,78 +40,6 @@ MODEL_SCORE_THRESHOLDS = {
 }


-def parse_models(model_string):
-    return [model.strip() for model in model_string.split(",") if model.strip()]
-
-
-def popen_launch_server_wrapper(base_url, model, is_tp2):
-    other_args = ["--log-level-http", "warning", "--trust-remote-code"]
-    if is_tp2:
-        other_args.extend(["--tp", "2"])
-
-    process = popen_launch_server(
-        model,
-        base_url,
-        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-        other_args=other_args,
-    )
-    return process
-
-
-def write_results_to_json(model, metrics, mode="a"):
-    result = {
-        "timestamp": datetime.now().isoformat(),
-        "model": model,
-        "metrics": metrics,
-        "score": metrics["score"],
-    }
-
-    existing_results = []
-    if mode == "a" and os.path.exists("results.json"):
-        try:
-            with open("results.json", "r") as f:
-                existing_results = json.load(f)
-        except json.JSONDecodeError:
-            existing_results = []
-
-    if isinstance(existing_results, list):
-        existing_results.append(result)
-    else:
-        existing_results = [result]
-
-    with open("results.json", "w") as f:
-        json.dump(existing_results, f, indent=2)
-
-
-def check_model_scores(results):
-    failed_models = []
-    summary = " | model | score | threshold |\n"
-    summary += "| ----- | ----- | --------- |\n"
-
-    for model, score in results:
-        threshold = MODEL_SCORE_THRESHOLDS.get(model)
-        if threshold is None:
-            print(f"Warning: No threshold defined for model {model}")
-            continue
-
-        if score < threshold:
-            failed_models.append(
-                f"\nScore Check Failed: {model}\n"
-                f"Model {model} score ({score:.4f}) is below threshold ({threshold:.4f})"
-            )
-
-        line = f"| {model} | {score} | {threshold} |\n"
-        summary += line
-
-    print(summary)
-
-    if is_in_ci():
-        write_github_step_summary(f"### TestNightlyGsm8KEval\n{summary}")
-
-    if failed_models:
-        raise AssertionError("\n".join(failed_models))
-
-
 # Do not use `CustomTestCase` since `test_mgsm_en_all_models` does not want retry
 class TestNightlyGsm8KEval(unittest.TestCase):
     @classmethod
@@ -131,11 +58,17 @@ class TestNightlyGsm8KEval(unittest.TestCase):
         )
         is_first = True
         all_results = []
+        model_count = 0

         for model_group, is_fp8, is_tp2 in self.model_groups:
             for model in model_group:
+                model_count += 1
                 with self.subTest(model=model):
-                    process = popen_launch_server_wrapper(self.base_url, model, is_tp2)
+                    process = popen_launch_server(
+                        model=model,
+                        base_url=self.base_url,
+                        other_args=["--tp", "2"] if is_tp2 else [],
+                        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                    )
                     args = SimpleNamespace(
                         base_url=self.base_url,
@@ -153,7 +86,8 @@ class TestNightlyGsm8KEval(unittest.TestCase):
                     write_results_to_json(model, metrics, "w" if is_first else "a")
                     is_first = False

-                    all_results.append((model, metrics["score"]))
+                    # 0.0 for empty latency
+                    all_results.append((model, metrics["score"], 0.0))
                     kill_process_tree(process.pid)

         try:
@@ -164,7 +98,12 @@ class TestNightlyGsm8KEval(unittest.TestCase):
             print(f"Error reading results.json: {e}")

         # Check all scores after collecting all results
-        check_model_scores(all_results)
+        check_evaluation_test_results(
+            all_results,
+            self.__class__.__name__,
+            model_accuracy_thresholds=MODEL_SCORE_THRESHOLDS,
+            model_count=model_count,
+        )


 if __name__ == "__main__":
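The renamed test now collects (model, score, latency) triples and delegates the pass/fail decision to check_evaluation_test_results from sglang.test.test_utils, which replaces the local check_model_scores. Its real implementation is added elsewhere in this commit and is not shown here; purely as a sketch of the behavior the call site implies (argument names copied from the call, everything else assumed):

    def check_evaluation_test_results_sketch(
        results,                       # iterable of (model, score, latency) tuples
        test_name,
        model_accuracy_thresholds,
        model_latency_thresholds=None,
        model_count=None,
    ):
        # Fail if any model misses its accuracy threshold or exceeds its latency
        # threshold, and optionally verify every expected model produced a result.
        results = list(results)
        failures = []
        for model, score, latency in results:
            acc = model_accuracy_thresholds.get(model)
            if acc is not None and score < acc:
                failures.append(f"{model}: score {score:.4f} < threshold {acc:.4f}")
            lat = (model_latency_thresholds or {}).get(model)
            if lat is not None and latency > lat:
                failures.append(f"{model}: latency {latency:.1f}s > threshold {lat:.1f}s")
        if model_count is not None and len(results) != model_count:
            failures.append(f"expected {model_count} results, got {len(results)}")
        if failures:
            raise AssertionError(f"{test_name}:\n" + "\n".join(failures))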
test/srt/test_nightly_text_models_perf.py (new file, mode 100644)

import os
import subprocess
import time
import unittest

from sglang.bench_one_batch_server import BenchmarkResult
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    _parse_int_list_env,
    is_in_ci,
    parse_models,
    popen_launch_server,
    write_github_step_summary,
)

PROFILE_DIR = "performance_profiles_text_models"


class TestNightlyTextModelsPerformance(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model_groups = [
            (parse_models("meta-llama/Llama-3.1-8B-Instruct"), False, False),
            (parse_models("Qwen/Qwen2-57B-A14B-Instruct"), False, True),
            # (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1), False, False),
            # (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2), False, True),
            # (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1), True, False),
            # (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2), True, True),
        ]
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.batch_sizes = [1, 1, 8, 16, 64]
        cls.input_lens = tuple(_parse_int_list_env("NIGHTLY_INPUT_LENS", "4096"))
        cls.output_lens = tuple(_parse_int_list_env("NIGHTLY_OUTPUT_LENS", "512"))
        os.makedirs(PROFILE_DIR, exist_ok=True)
        cls.full_report = f"## {cls.__name__}\n" + BenchmarkResult.help_str()

    def test_bench_one_batch(self):
        all_benchmark_results = []

        for model_group, is_fp8, is_tp2 in self.model_groups:
            for model in model_group:
                benchmark_results = []
                with self.subTest(model=model):
                    process = popen_launch_server(
                        model=model,
                        base_url=self.base_url,
                        other_args=["--tp", "2"] if is_tp2 else [],
                        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    )
                    try:
                        profile_filename = (
                            f"{model.replace('/', '_')}_{int(time.time())}"
                        )
                        profile_path_prefix = os.path.join(PROFILE_DIR, profile_filename)
                        json_output_file = (
                            f"results_{model.replace('/', '_')}_{int(time.time())}.json"
                        )

                        command = [
                            "python3",
                            "-m",
                            "sglang.bench_one_batch_server",
                            "--model",
                            model,
                            "--base-url",
                            self.base_url,
                            "--batch-size",
                            *[str(x) for x in self.batch_sizes],
                            "--input-len",
                            *[str(x) for x in self.input_lens],
                            "--output-len",
                            *[str(x) for x in self.output_lens],
                            "--show-report",
                            "--profile",
                            "--profile-by-stage",
                            "--profile-filename-prefix",
                            profile_path_prefix,
                            f"--output-path={json_output_file}",
                            "--no-append-to-github-summary",
                        ]
                        print(f"Running command: {' '.join(command)}")
                        result = subprocess.run(command, capture_output=True, text=True)

                        if result.returncode != 0:
                            print(f"Error running benchmark for {model} with batch size:")
                            print(result.stderr)
                            # Continue to next batch size even if one fails
                            continue

                        # Load and deserialize JSON results
                        if os.path.exists(json_output_file):
                            import json

                            with open(json_output_file, "r") as f:
                                json_data = json.load(f)

                            # Convert JSON data to BenchmarkResult objects
                            for data in json_data:
                                benchmark_result = BenchmarkResult(**data)
                                all_benchmark_results.append(benchmark_result)
                                benchmark_results.append(benchmark_result)

                            print(
                                f"Loaded {len(benchmark_results)} benchmark results from {json_output_file}"
                            )

                            # Clean up JSON file
                            os.remove(json_output_file)
                        else:
                            print(f"Warning: JSON output file {json_output_file} not found")
                    finally:
                        kill_process_tree(process.pid)

                report_part = BenchmarkResult.generate_markdown_report(
                    PROFILE_DIR, benchmark_results
                )
                self.full_report += report_part + "\n"

        if is_in_ci():
            write_github_step_summary(self.full_report)


if __name__ == "__main__":
    unittest.main()
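The input and output lengths used by the sweep above can be overridden through the NIGHTLY_INPUT_LENS and NIGHTLY_OUTPUT_LENS environment variables, both parsed with _parse_int_list_env from sglang.test.test_utils. That helper is not shown in this diff; a plausible sketch, offered only as an assumption about its behavior:

    import os

    def _parse_int_list_env_sketch(name: str, default: str) -> list[int]:
        # Read a comma-separated list of integers from the environment,
        # falling back to the given default string.
        return [int(x) for x in os.environ.get(name, default).split(",") if x.strip()]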
test/srt/test_nightly_vlms_mmmu_eval.py (new file, mode 100644)

import json
import unittest
import warnings
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    ModelDeploySetup,
    ModelEvalMetrics,
    check_evaluation_test_results,
    popen_launch_server,
    write_results_to_json,
)

MODEL_THRESHOLDS = {
    # Conservative thresholds on 100 MMMU samples, especially for latency thresholds
    ModelDeploySetup("deepseek-ai/deepseek-vl2-small"): ModelEvalMetrics(0.330, 56.1),
    ModelDeploySetup("deepseek-ai/Janus-Pro-7B"): ModelEvalMetrics(0.285, 39.9),
    ModelDeploySetup("Efficient-Large-Model/NVILA-Lite-2B-hf-0626"): ModelEvalMetrics(0.305, 23.8),
    ModelDeploySetup("google/gemma-3-4b-it"): ModelEvalMetrics(0.360, 10.9),
    ModelDeploySetup("google/gemma-3n-E4B-it"): ModelEvalMetrics(0.360, 15.3),
    ModelDeploySetup("mistral-community/pixtral-12b"): ModelEvalMetrics(0.360, 14.5),
    ModelDeploySetup("moonshotai/Kimi-VL-A3B-Instruct"): ModelEvalMetrics(0.330, 22.3),
    ModelDeploySetup("openbmb/MiniCPM-o-2_6"): ModelEvalMetrics(0.330, 29.3),
    ModelDeploySetup("openbmb/MiniCPM-v-2_6"): ModelEvalMetrics(0.270, 24.5),
    ModelDeploySetup("OpenGVLab/InternVL2_5-2B"): ModelEvalMetrics(0.300, 14.0),
    ModelDeploySetup("Qwen/Qwen2-VL-7B-Instruct"): ModelEvalMetrics(0.310, 83.3),
    ModelDeploySetup("Qwen/Qwen2.5-VL-7B-Instruct"): ModelEvalMetrics(0.340, 31.9),
    ModelDeploySetup("unsloth/Mistral-Small-3.1-24B-Instruct-2503"): ModelEvalMetrics(0.310, 16.7),
    ModelDeploySetup("XiaomiMiMo/MiMo-VL-7B-RL"): ModelEvalMetrics(0.28, 32.0),
    ModelDeploySetup("zai-org/GLM-4.1V-9B-Thinking"): ModelEvalMetrics(0.280, 30.4),
}


class TestNightlyVLMMmmuEval(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.models = list(MODEL_THRESHOLDS.keys())
        cls.base_url = DEFAULT_URL_FOR_TEST

    def test_mmmu_vlm_models(self):
        warnings.filterwarnings(
            "ignore", category=ResourceWarning, message="unclosed.*socket"
        )
        is_first = True
        all_results = []

        for model in self.models:
            model_path = model.model_path
            with self.subTest(model=model_path):
                process = popen_launch_server(
                    model=model_path,
                    base_url=self.base_url,
                    other_args=model.extra_args,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                )
                try:
                    args = SimpleNamespace(
                        base_url=self.base_url,
                        model=model_path,
                        eval_name="mmmu",
                        num_examples=100,
                        num_threads=64,
                        max_tokens=30,
                    )
                    args.return_latency = True
                    metrics, latency = run_eval(args)

                    metrics["score"] = round(metrics["score"], 4)
                    metrics["latency"] = round(latency, 4)
                    print(
                        f"{'=' * 42}\n{model_path} - metrics={metrics} score={metrics['score']}\n{'=' * 42}\n"
                    )

                    write_results_to_json(model_path, metrics, "w" if is_first else "a")
                    is_first = False

                    all_results.append(
                        (model_path, metrics["score"], metrics["latency"])
                    )
                finally:
                    kill_process_tree(process.pid)

        try:
            with open("results.json", "r") as f:
                print("\nFinal Results from results.json:")
                print(json.dumps(json.load(f), indent=2))
        except Exception as e:
            print(f"Error reading results: {e}")

        model_accuracy_thresholds = {
            model.model_path: threshold.accuracy
            for model, threshold in MODEL_THRESHOLDS.items()
        }
        model_latency_thresholds = {
            model.model_path: threshold.eval_time
            for model, threshold in MODEL_THRESHOLDS.items()
        }

        check_evaluation_test_results(
            all_results,
            self.__class__.__name__,
            model_accuracy_thresholds=model_accuracy_thresholds,
            model_latency_thresholds=model_latency_thresholds,
        )


if __name__ == "__main__":
    unittest.main()
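The MODEL_THRESHOLDS keys are ModelDeploySetup objects (so each entry can also carry extra server launch flags), and the values are ModelEvalMetrics pairs of minimum accuracy and maximum eval time in seconds. Both types come from sglang.test.test_utils and are added in this commit but not shown here; a rough sketch of what they could look like, inferred only from how they are used above and not from the actual definitions:

    from dataclasses import dataclass

    @dataclass(frozen=True)  # frozen so instances can serve as dict keys, as above
    class ModelDeploySetupSketch:
        model_path: str
        extra_args: tuple = ()  # extra flags forwarded to popen_launch_server

    @dataclass
    class ModelEvalMetricsSketch:
        accuracy: float   # minimum acceptable MMMU score
        eval_time: float  # maximum acceptable evaluation latency, in seconds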
test/srt/test_nightly_vlms_perf.py (new file, mode 100644)

import os
import subprocess
import unittest
import warnings

from sglang.bench_one_batch_server import BenchmarkResult
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    _parse_int_list_env,
    is_in_ci,
    parse_models,
    popen_launch_server,
    write_github_step_summary,
)

PROFILE_DIR = "performance_profiles_vlms"

MODEL_DEFAULTS = [
    # Keep conservative defaults. Can be overridden by env NIGHTLY_VLM_MODELS
    "Qwen/Qwen2.5-VL-7B-Instruct",
    "google/gemma-3-27b-it",
    # "OpenGVLab/InternVL2_5-2B",
    # buggy in official transformers impl
    # "openbmb/MiniCPM-V-2_6",
]


class TestNightlyVLMModelsPerformance(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        warnings.filterwarnings(
            "ignore", category=ResourceWarning, message="unclosed.*socket"
        )
        cls.models = parse_models(
            os.environ.get("NIGHTLY_VLM_MODELS", ",".join(MODEL_DEFAULTS))
        )
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.batch_sizes = _parse_int_list_env("NIGHTLY_VLM_BATCH_SIZES", "1,1,2,8,16")
        cls.input_lens = tuple(_parse_int_list_env("NIGHTLY_VLM_INPUT_LENS", "4096"))
        cls.output_lens = tuple(_parse_int_list_env("NIGHTLY_VLM_OUTPUT_LENS", "512"))
        cls.full_report = f"## {cls.__name__}\n" + BenchmarkResult.help_str()

    def test_bench_one_batch(self):
        all_benchmark_results = []

        for model in self.models:
            benchmark_results = []
            with self.subTest(model=model):
                process = popen_launch_server(
                    model=model,
                    base_url=self.base_url,
                    other_args=["--mem-fraction-static=0.7"],
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                )
                try:
                    # Run bench_one_batch_server against the launched server
                    profile_filename = f"{model.replace('/', '_')}"
                    # path for this run
                    profile_path_prefix = os.path.join(PROFILE_DIR, profile_filename)
                    # JSON output file for this model
                    json_output_file = f"results_{model.replace('/', '_')}.json"

                    command = [
                        "python3",
                        "-m",
                        "sglang.bench_one_batch_server",
                        f"--model={model}",
                        "--base-url",
                        self.base_url,
                        "--batch-size",
                        *[str(x) for x in self.batch_sizes],
                        "--input-len",
                        *[str(x) for x in self.input_lens],
                        "--output-len",
                        *[str(x) for x in self.output_lens],
                        "--trust-remote-code",
                        "--dataset-name=mmmu",
                        "--profile",
                        "--profile-by-stage",
                        f"--profile-filename-prefix={profile_path_prefix}",
                        "--show-report",
                        f"--output-path={json_output_file}",
                        "--no-append-to-github-summary",
                    ]
                    print(f"Running command: {' '.join(command)}")
                    result = subprocess.run(command, capture_output=True, text=True)

                    if result.returncode != 0:
                        print(f"Error running benchmark for {model} with batch size:")
                        print(result.stderr)
                        # Continue to next batch size even if one fails
                        continue

                    print(f"Output for {model} with batch size:")
                    print(result.stdout)

                    # Load and deserialize JSON results
                    if os.path.exists(json_output_file):
                        import json

                        with open(json_output_file, "r") as f:
                            json_data = json.load(f)

                        # Convert JSON data to BenchmarkResult objects
                        for data in json_data:
                            benchmark_result = BenchmarkResult(**data)
                            all_benchmark_results.append(benchmark_result)
                            benchmark_results.append(benchmark_result)

                        print(
                            f"Loaded {len(benchmark_results)} benchmark results from {json_output_file}"
                        )
                    else:
                        print(f"Warning: JSON output file {json_output_file} not found")
                finally:
                    kill_process_tree(process.pid)

            report_part = BenchmarkResult.generate_markdown_report(
                PROFILE_DIR, benchmark_results
            )
            self.full_report += report_part + "\n"

        if is_in_ci():
            write_github_step_summary(self.full_report)


if __name__ == "__main__":
    unittest.main()
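To reproduce a narrower version of this benchmark locally, the environment overrides read in setUpClass can be set before invoking the test file. A minimal sketch with example values, assuming a machine with the required GPUs and this repository checked out:

    import os
    import subprocess

    # Benchmark a single VLM with a reduced sweep (all values are examples).
    env = dict(
        os.environ,
        NIGHTLY_VLM_MODELS="Qwen/Qwen2.5-VL-7B-Instruct",
        NIGHTLY_VLM_BATCH_SIZES="1,8",
        NIGHTLY_VLM_INPUT_LENS="1024",
        NIGHTLY_VLM_OUTPUT_LENS="128",
    )
    subprocess.run(["python3", "test/srt/test_nightly_vlms_perf.py"], env=env, check=True)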
test/srt/test_vllm_dependency.py

@@ -14,6 +14,7 @@ from sglang.test.test_utils import (
     is_in_ci,
     popen_launch_server,
     write_github_step_summary,
+    write_results_to_json,
 )

 MODEL_SCORE_THRESHOLDS = {
@@ -52,31 +53,6 @@ def popen_launch_server_wrapper(base_url, model, is_fp8, is_tp2):
     return process


-def write_results_to_json(model, metrics, mode="a"):
-    result = {
-        "timestamp": datetime.now().isoformat(),
-        "model": model,
-        "metrics": metrics,
-        "score": metrics["score"],
-    }
-
-    existing_results = []
-    if mode == "a" and os.path.exists("results.json"):
-        try:
-            with open("results.json", "r") as f:
-                existing_results = json.load(f)
-        except json.JSONDecodeError:
-            existing_results = []
-
-    if isinstance(existing_results, list):
-        existing_results.append(result)
-    else:
-        existing_results = [result]
-
-    with open("results.json", "w") as f:
-        json.dump(existing_results, f, indent=2)
-
-
 def check_model_scores(results):
     failed_models = []
     summary = " | model | score | threshold |\n"