Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
861c5c27
Unverified
Commit
861c5c27
authored
Mar 11, 2025
by
Stella Biderman
Committed by
GitHub
Mar 11, 2025
Browse files
Create utils.py
parent
09228840
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
230 additions
and
0 deletions
+230
-0
lm_eval/tasks/aime24/utils.py
lm_eval/tasks/aime24/utils.py
+230
-0
No files found.
lm_eval/tasks/aime24/utils.py
0 → 100644
View file @
861c5c27
from
typing
import
Dict
,
List
import
datasets
def
process_docs
(
dataset
:
datasets
.
Dataset
)
->
datasets
.
Dataset
:
def
_process_doc
(
doc
:
dict
)
->
dict
:
out_doc
=
{
"problem"
:
doc
[
"problem"
],
"solution"
:
doc
[
"solution"
],
"answer"
:
remove_boxed
(
last_boxed_only_string
(
doc
[
"solution"
])),
}
return
out_doc
return
dataset
.
map
(
_process_doc
)
def
process_results
(
doc
:
dict
,
results
:
List
[
str
])
->
Dict
[
str
,
int
]:
retval
=
0
indices
=
[
pos
for
pos
,
char
in
enumerate
(
results
[
0
])
if
char
==
"$"
]
if
len
(
indices
)
<=
1
:
answer
=
results
[
0
]
else
:
answer
=
results
[
0
][
indices
[
0
]
+
1
:
indices
[
-
1
]]
if
is_equiv
(
answer
,
remove_boxed
(
last_boxed_only_string
(
doc
[
"solution"
]))):
retval
=
1
results
=
{
"exact_match"
:
retval
,
}
return
results
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def
is_equiv
(
str1
,
str2
,
verbose
=
False
):
if
str1
is
None
and
str2
is
None
:
print
(
"WARNING: Both None"
)
return
True
if
str1
is
None
or
str2
is
None
:
return
False
try
:
ss1
=
strip_string
(
str1
)
ss2
=
strip_string
(
str2
)
if
verbose
:
print
(
ss1
,
ss2
)
return
ss1
==
ss2
except
Exception
:
return
str1
==
str2
def
remove_boxed
(
s
):
if
"
\\
boxed "
in
s
:
left
=
"
\\
boxed "
assert
s
[:
len
(
left
)]
==
left
return
s
[
len
(
left
)
:]
left
=
"
\\
boxed{"
assert
s
[:
len
(
left
)]
==
left
assert
s
[
-
1
]
==
"}"
return
s
[
len
(
left
)
:
-
1
]
def
last_boxed_only_string
(
string
):
idx
=
string
.
rfind
(
"
\\
boxed"
)
if
"
\\
boxed "
in
string
:
return
"
\\
boxed "
+
string
.
split
(
"
\\
boxed "
)[
-
1
].
split
(
"$"
)[
0
]
if
idx
<
0
:
idx
=
string
.
rfind
(
"
\\
fbox"
)
if
idx
<
0
:
return
None
i
=
idx
right_brace_idx
=
None
num_left_braces_open
=
0
while
i
<
len
(
string
):
if
string
[
i
]
==
"{"
:
num_left_braces_open
+=
1
if
string
[
i
]
==
"}"
:
num_left_braces_open
-=
1
if
num_left_braces_open
==
0
:
right_brace_idx
=
i
break
i
+=
1
if
right_brace_idx
is
None
:
retval
=
None
else
:
retval
=
string
[
idx
:
right_brace_idx
+
1
]
return
retval
def
fix_fracs
(
string
):
substrs
=
string
.
split
(
"
\\
frac"
)
new_str
=
substrs
[
0
]
if
len
(
substrs
)
>
1
:
substrs
=
substrs
[
1
:]
for
substr
in
substrs
:
new_str
+=
"
\\
frac"
if
substr
[
0
]
==
"{"
:
new_str
+=
substr
else
:
try
:
assert
len
(
substr
)
>=
2
except
AssertionError
:
return
string
a
=
substr
[
0
]
b
=
substr
[
1
]
if
b
!=
"{"
:
if
len
(
substr
)
>
2
:
post_substr
=
substr
[
2
:]
new_str
+=
"{"
+
a
+
"}{"
+
b
+
"}"
+
post_substr
else
:
new_str
+=
"{"
+
a
+
"}{"
+
b
+
"}"
else
:
if
len
(
substr
)
>
2
:
post_substr
=
substr
[
2
:]
new_str
+=
"{"
+
a
+
"}"
+
b
+
post_substr
else
:
new_str
+=
"{"
+
a
+
"}"
+
b
string
=
new_str
return
string
def
fix_a_slash_b
(
string
):
if
len
(
string
.
split
(
"/"
))
!=
2
:
return
string
a
=
string
.
split
(
"/"
)[
0
]
b
=
string
.
split
(
"/"
)[
1
]
try
:
a
=
int
(
a
)
b
=
int
(
b
)
assert
string
==
"{}/{}"
.
format
(
a
,
b
)
new_string
=
"
\\
frac{"
+
str
(
a
)
+
"}{"
+
str
(
b
)
+
"}"
return
new_string
except
AssertionError
:
return
string
def
remove_right_units
(
string
):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if
"
\\
text{ "
in
string
:
splits
=
string
.
split
(
"
\\
text{ "
)
assert
len
(
splits
)
==
2
return
splits
[
0
]
else
:
return
string
def
fix_sqrt
(
string
):
if
"
\\
sqrt"
not
in
string
:
return
string
splits
=
string
.
split
(
"
\\
sqrt"
)
new_string
=
splits
[
0
]
for
split
in
splits
[
1
:]:
if
split
[
0
]
!=
"{"
:
a
=
split
[
0
]
new_substr
=
"
\\
sqrt{"
+
a
+
"}"
+
split
[
1
:]
else
:
new_substr
=
"
\\
sqrt"
+
split
new_string
+=
new_substr
return
new_string
def
strip_string
(
string
):
# linebreaks
string
=
string
.
replace
(
"
\n
"
,
""
)
# remove inverse spaces
string
=
string
.
replace
(
"
\\
!"
,
""
)
# replace \\ with \
string
=
string
.
replace
(
"
\\\\
"
,
"
\\
"
)
# replace tfrac and dfrac with frac
string
=
string
.
replace
(
"tfrac"
,
"frac"
)
string
=
string
.
replace
(
"dfrac"
,
"frac"
)
# remove \left and \right
string
=
string
.
replace
(
"
\\
left"
,
""
)
string
=
string
.
replace
(
"
\\
right"
,
""
)
# Remove circ (degrees)
string
=
string
.
replace
(
"^{
\\
circ}"
,
""
)
string
=
string
.
replace
(
"^
\\
circ"
,
""
)
# remove dollar signs
string
=
string
.
replace
(
"
\\
$"
,
""
)
# remove units (on the right)
string
=
remove_right_units
(
string
)
# remove percentage
string
=
string
.
replace
(
"
\\
%"
,
""
)
string
=
string
.
replace
(
"\%"
,
""
)
# noqa: W605
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string
=
string
.
replace
(
" ."
,
" 0."
)
string
=
string
.
replace
(
"{."
,
"{0."
)
# if empty, return empty string
if
len
(
string
)
==
0
:
return
string
if
string
[
0
]
==
"."
:
string
=
"0"
+
string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if
len
(
string
.
split
(
"="
))
==
2
:
if
len
(
string
.
split
(
"="
)[
0
])
<=
2
:
string
=
string
.
split
(
"="
)[
1
]
# fix sqrt3 --> sqrt{3}
string
=
fix_sqrt
(
string
)
# remove spaces
string
=
string
.
replace
(
" "
,
""
)
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string
=
fix_fracs
(
string
)
# manually change 0.5 --> \frac{1}{2}
if
string
==
"0.5"
:
string
=
"
\\
frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string
=
fix_a_slash_b
(
string
)
return
string
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment