Commit 62222bd2 (unverified), authored Jul 27, 2025 by fzyzcjy, committed by GitHub on Jul 27, 2025
Parent: ed0fdbf3

Minor tool for comparison of benchmark results (#7974)

Showing 4 changed files with 222 additions and 0 deletions:

  +7   -0   benchmark/gsm8k/bench_sglang.py
  +8   -0   benchmark/mmlu/bench_sglang.py
  +172 -0   python/sglang/srt/debug_utils/text_comparator.py
  +35  -0   python/sglang/test/test_utils.py
benchmark/gsm8k/bench_sglang.py

@@ -10,6 +10,7 @@ import numpy as np
 from sglang.api import set_default_backend
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
+    dump_bench_raw_result,
     select_sglang_backend,
 )
 from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
@@ -115,6 +116,12 @@ def main(args):
     # Dump results
     dump_state_text(f"tmp_output_{args.backend}.txt", states)
+    dump_bench_raw_result(
+        path=args.raw_result_file,
+        states=states,
+        preds=preds,
+        labels=labels,
+    )

     with open(args.result_file, "a") as fout:
         value = {
benchmark/mmlu/bench_sglang.py

@@ -9,6 +9,7 @@ import tiktoken
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
+    dump_bench_raw_result,
     select_sglang_backend,
 )
@@ -142,6 +143,13 @@ def main(args):
     assert pt == len(cors)

     weighted_acc = np.mean(cors)
+    dump_bench_raw_result(
+        path=args.raw_result_file,
+        states=states,
+        preds=preds,
+        labels=labels,
+    )
+
     # Print results
     print("Total latency: {:.3f}".format(latency))
     print("Average accuracy: {:.3f}".format(weighted_acc))
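Both benchmark scripts gate the new dump on the --raw-result-file flag (added to add_common_sglang_args_and_parse in python/sglang/test/test_utils.py below). A hedged sketch of producing two raw-result files for later comparison, assuming an SGLang server is already serving the corresponding build for each run and using invented file names:

# Sketch only: file names are invented, and the usual serving/benchmark flags
# (model, port, number of questions, ...) are omitted for brevity.
import subprocess

for label in ("baseline", "target"):
    # Re-run after switching the server to the corresponding build.
    subprocess.run(
        [
            "python",
            "benchmark/gsm8k/bench_sglang.py",
            "--raw-result-file",
            f"{label}_raw_results.jsonl",
        ],
        check=True,
    )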
python/sglang/srt/debug_utils/text_comparator.py  (new file, 0 → 100644)

import argparse
import json
from pathlib import Path

import polars as pl

_DESCRIPTION = """Compare and find differences to benchmark outputs.

Supported inputs:
* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`
* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)
"""


def main(args):
    df_input = _transform_df_input(_compute_df_raw(args))
    assert all(
        c in df_input.columns
        for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
    )

    df_meta = _compute_df_meta(df_input)

    df_correctness_per_trial = df_input.group_by(
        "category", "trial_index", maintain_order=True
    ).agg(pl.col("correct").mean())
    df_correctness_delta = (
        df_meta.group_by("correctness_delta").len().sort("correctness_delta")
    )
    df_good_to_bad = df_meta.filter(pl.col("correctness_delta") < 0)
    df_bad_to_good = df_meta.filter(pl.col("correctness_delta") > 0)

    print(f"Dump output to {args.output_path}")
    Path(args.output_path).write_text(
        json.dumps(
            dict(
                df_meta=df_meta.to_dicts(),
                df_good_to_bad=df_good_to_bad.to_dicts(),
                df_bad_to_good=df_bad_to_good.to_dicts(),
            )
        )
    )

    if not args.disable_print_details:
        with pl.Config(
            fmt_str_lengths=10000,
            tbl_cols=-1,
            tbl_rows=-1,
            tbl_width_chars=-1,
            tbl_formatting="UTF8_FULL",
        ):
            print("====== Correctness per trial ======")
            print(df_correctness_per_trial)

            print("====== Correctness Delta (-1.0 means all-right becomes all-wrong) ======")
            print(df_correctness_delta)

            for name, df in [
                ("Good->Bad", df_good_to_bad),
                ("Bad->Good", df_bad_to_good),
            ]:
                print(f"====== Concrete Examples: {name} ======")
                print(df)


def _compute_df_raw(args):
    return pl.concat(
        [
            _read_df_raw(p, category=category, trial_index=i)
            for category, paths in [
                ("baseline", args.baseline_path),
                ("target", args.target_path),
            ]
            for i, p in enumerate(paths)
        ]
    )


def _read_df_raw(path: str, category: str, trial_index: int):
    return pl.read_ndjson(path).with_columns(
        category=pl.lit(category), trial_index=trial_index
    )


def _transform_df_input(df: pl.DataFrame):
    if "doc_id" in df.columns:
        print("Transform mode: lm_eval")

        filter_names = df["filter"].unique(maintain_order=True).to_list()
        if len(filter_names) > 1:
            filter_name = filter_names[0]
            print(f"Choose {filter_name=} among {filter_names}")
            df = df.filter(pl.col("filter") == filter_name)

        df = df.select(
            pl.col("category"),
            pl.col("trial_index"),
            prompt_id=pl.col("doc_id"),
            prompt=pl.col("arguments").struct.field("gen_args_0").struct.field("arg_0"),
            output=pl.col("resps").list.get(0).list.get(0),
            correct=pl.col("exact_match").cast(bool),
        )

        return df
    elif "prompt_id" in df.columns:
        print("Transform mode: SGLang bench")
        return df
    else:
        raise Exception(f"Unknown data: {df.columns}")


def _compute_df_meta(df_input: pl.DataFrame):
    df_input = df_input.sort("prompt_id", "category", "trial_index")
    df_meta = pl.DataFrame(
        [
            _handle_one_prompt(df_one_prompt)
            for df_one_prompt in df_input.partition_by("prompt_id", maintain_order=True)
        ]
    )
    df_meta = df_meta.with_columns(
        correctness_delta=pl.col("correctness_target") - pl.col("correctness_baseline"),
    )
    df_meta = df_meta.sort("correctness_delta", "output_same_prefix_len")
    return df_meta


def _handle_one_prompt(df_one_prompt: pl.DataFrame):
    assert len(set(df_one_prompt["prompt"])) == 1

    df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
    df_target = df_one_prompt.filter(pl.col("category") == "target")

    outputs_baseline = df_baseline["output"].to_list()
    outputs_target = df_target["output"].to_list()

    output_same_prefix_len = max(
        _compute_str_prefix_len(output_baseline, output_target)
        for output_baseline in outputs_baseline
        for output_target in outputs_target
    )

    return dict(
        prompt_id=df_one_prompt[0, "prompt_id"],
        correctness_baseline=df_baseline["correct"].mean(),
        correctness_target=df_target["correct"].mean(),
        output_same_prefix_len=output_same_prefix_len,
        prompt=df_one_prompt[0, "prompt"],
        outputs_baseline=outputs_baseline,
        outputs_target=outputs_target,
    )


def _compute_str_prefix_len(a: str, b: str) -> int:
    min_len = min(len(a), len(b))
    for i in range(min_len):
        if a[i] != b[i]:
            return i
    return min_len


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=_DESCRIPTION)
    parser.add_argument("--baseline-path", type=str, nargs="+")
    parser.add_argument("--target-path", type=str, nargs="+")
    parser.add_argument(
        "--output-path", type=str, default="/tmp/text_comparator_output.json"
    )
    parser.add_argument("--disable-print-details", action="store_true")
    args = parser.parse_args()
    main(args)
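The comparator is meant to be run as a CLI with the flags defined in its __main__ block above; as a minimal sketch, it can also be driven programmatically by handing main any object carrying the same attributes (the raw-result file names below are hypothetical):

# Minimal sketch: the paths are hypothetical NDJSON files produced either by
# `bench_sglang.py --raw-result-file ...` or by `lm_eval --log_samples`.
from types import SimpleNamespace

from sglang.srt.debug_utils.text_comparator import main

main(
    SimpleNamespace(
        baseline_path=["baseline_raw_results.jsonl"],  # one or more baseline runs
        target_path=["target_raw_results.jsonl"],      # one or more target runs
        output_path="/tmp/text_comparator_output.json",
        disable_print_details=False,  # keep the polars tables printed to stdout
    )
)

In the printed correctness-delta table, -1.0 means a prompt went from all-correct on the baseline to all-wrong on the target, while positive values mark the opposite direction; rows with small output_same_prefix_len diverge early and are usually the most interesting to inspect.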
python/sglang/test/test_utils.py

@@ -15,6 +15,7 @@ import unittest
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from functools import partial
+from pathlib import Path
 from types import SimpleNamespace
 from typing import Awaitable, Callable, List, Optional, Tuple
@@ -27,6 +28,7 @@ from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.lang.interpreter import ProgramState
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device,
@@ -348,6 +350,7 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
         help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms",
     )
     parser.add_argument("--result-file", type=str, default="result.jsonl")
+    parser.add_argument("--raw-result-file", type=str)
     args = parser.parse_args()
     return args
@@ -1309,3 +1312,35 @@ class CustomTestCase(unittest.TestCase):
             lambda: super(CustomTestCase, self)._callTestMethod(method),
             max_retry=max_retry,
         )
+
+
+def dump_bench_raw_result(
+    path: str,
+    states,
+    preds,
+    labels,
+):
+    if not path:
+        return
+
+    rows = []
+    for i in range(len(states)):
+        state = states[i]
+        output = state["answer"]
+        prompt = _ensure_remove_suffix(state.text(), output)
+        rows.append(
+            dict(
+                prompt_id=i,
+                prompt=prompt,
+                output=output,
+                correct=bool(preds[i] == labels[i]),
+            )
+        )
+
+    print(f"BenchRawResultDumper save results to {path}")
+    Path(path).write_text("\n".join(json.dumps(row) for row in rows))
+
+
+def _ensure_remove_suffix(text: str, suffix: str):
+    assert text.endswith(suffix)
+    return text.removesuffix(suffix)
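For reference, each line written by dump_bench_raw_result is one standalone JSON object with the four keys built in the loop above; a hypothetical row (text values invented for illustration) looks like:

# Hypothetical NDJSON row; only the key names are taken from the code above.
row = {
    "prompt_id": 0,
    "prompt": "Question: ... Answer:",
    "output": " The answer is 42.",
    "correct": True,
}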