Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
bb121214
Unverified
Commit
bb121214
authored
Feb 19, 2025
by
simveit
Committed by
GitHub
Feb 20, 2025
Browse files
Variance measure for reasoning benchmark (#3677)
parent
55de40f7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
10 deletions
+26
-10
benchmark/reasoning_benchmark/bench_sglang.py
benchmark/reasoning_benchmark/bench_sglang.py
+26
-10
No files found.
benchmark/reasoning_benchmark/bench_sglang.py
View file @
bb121214
...
...
@@ -4,6 +4,7 @@ import time
import
answer_extraction
import
eval_utils
import
numpy
as
np
from
datasets
import
load_dataset
import
sglang
as
sgl
...
...
@@ -61,26 +62,40 @@ def main(args):
)
latency
=
time
.
time
()
-
tic
# Extract
answers
correct
=
0
# Extract
results and record outcomes in a list.
outcomes
=
[]
for
i
,
state
in
enumerate
(
states
):
try
:
pred_answer
=
answer_extraction
.
extract_math_answer
(
questions
[
i
][
"question"
],
state
[
"answer"
],
"limo"
)
gt_answer
=
str
(
answers
[
i
][
"answer"
])
# Use last answer if multiple were extracted
pred_answer
=
(
pred_answer
[
-
1
]
if
isinstance
(
pred_answer
,
list
)
else
pred_answer
)
correct
+
=
1
if
eval_utils
.
math_equal
(
pred_answer
,
gt_answer
)
else
0
is_
correct
=
1
if
eval_utils
.
math_equal
(
pred_answer
,
gt_answer
)
else
0
except
Exception
as
e
:
print
(
f
"Error extracting answer:
{
e
}
"
)
pass
# Calculate accuracy
accuracy
=
correct
/
len
(
questions
)
print
(
f
"Accuracy:
{
accuracy
}
"
)
is_correct
=
0
outcomes
.
append
(
is_correct
)
# Calculate overall accuracy using numpy
overall_accuracy
=
np
.
mean
(
outcomes
)
print
(
f
"Overall Accuracy:
{
overall_accuracy
}
"
)
# Calculate mean standard error over questions if num_tries >= 2
if
args
.
num_tries
>
1
:
outcomes_np
=
np
.
array
(
outcomes
).
reshape
(
-
1
,
args
.
num_tries
)
# Using sample standard deviation with ddof=1
std_per_question
=
np
.
std
(
outcomes_np
,
axis
=
1
,
ddof
=
1
)
# Compute the standard error for each question: std / sqrt(num_tries)
se_per_question
=
std_per_question
/
np
.
sqrt
(
args
.
num_tries
)
mean_se
=
se_per_question
.
mean
()
print
(
f
"Mean Standard Error of Accuracy across questions:
{
mean_se
}
"
)
else
:
mean_se
=
None
print
(
"Not enough samples per question to compute standard error."
)
# Calculate output throughput
num_output_tokens
=
sum
(
...
...
@@ -98,7 +113,8 @@ def main(args):
"task"
:
"limo"
,
"backend"
:
args
.
backend
,
"latency"
:
round
(
latency
,
3
),
"accuracy"
:
round
(
accuracy
,
3
),
"overall_accuracy"
:
round
(
overall_accuracy
,
3
),
"mean_se_accuracy"
:
round
(
mean_se
,
3
)
if
mean_se
is
not
None
else
None
,
"num_requests"
:
len
(
questions
),
"other"
:
{
"num_questions"
:
len
(
questions
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment