Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
95daf009
Unverified
Commit
95daf009
authored
Mar 25, 2021
by
Leo Gao
Committed by
GitHub
Mar 25, 2021
Browse files
Merge branch 'master' into master
parents
e2112196
6a78fdaf
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
354 additions
and
3 deletions
+354
-3
README.md
README.md
+7
-0
lm_eval/base.py
lm_eval/base.py
+3
-1
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+10
-0
lm_eval/tasks/logiqa.py
lm_eval/tasks/logiqa.py
+19
-2
lm_eval/tasks/math.py
lm_eval/tasks/math.py
+315
-0
No files found.
README.md
View file @
95daf009
...
...
@@ -59,6 +59,13 @@ The goal of this project is to build a set of tools for evaluating LMs on typica
|ethics_utilitarianism_original|✓ |✓ |✓ |acc |
|ethics_utilitarianism |✓ |✓ |✓ |acc |
|ethics_virtue |✓ |✓ |✓ |acc, em |
|math_algebra |✓ | |✓ |acc |
|math_counting_and_prob |✓ | |✓ |acc |
|math_geometry |✓ | |✓ |acc |
|math_intermediate_algebra |✓ | |✓ |acc |
|math_num_theory |✓ | |✓ |acc |
|math_prealgebra |✓ | |✓ |acc |
|math_precalc |✓ | |✓ |acc |
|arithmetic_2da | |✓ | |acc |
|arithmetic_2ds | |✓ | |acc |
|arithmetic_3da | |✓ | |acc |
...
...
lm_eval/base.py
View file @
95daf009
...
...
@@ -117,7 +117,9 @@ class Task(abc.ABC):
def
fewshot_examples
(
self
,
k
):
if
self
.
_training_docs
is
None
:
self
.
_training_docs
=
list
(
self
.
training_docs
())
return
random
.
sample
(
self
.
_training_docs
,
k
)
rnd
=
random
.
Random
()
rnd
.
seed
(
42
)
return
rnd
.
sample
(
self
.
_training_docs
,
k
)
@
abc
.
abstractmethod
def
doc_to_text
(
self
,
doc
):
...
...
lm_eval/tasks/__init__.py
View file @
95daf009
...
...
@@ -34,6 +34,7 @@ from . import drop
from
.
import
unscramble
from
.
import
logiqa
from
.
import
hendrycks_test
from
.
import
math
########################################
# Translation tasks
...
...
@@ -127,6 +128,15 @@ TASK_REGISTRY = {
"ethics_utilitarianism"
:
ethics
.
EthicsUtilitarianism
,
"ethics_virtue"
:
ethics
.
EthicsVirtue
,
# math
"math_algebra"
:
math
.
MathAlgebra
,
"math_counting_and_prob"
:
math
.
MathCountingAndProbability
,
"math_geometry"
:
math
.
MathGeometry
,
"math_intermediate_algebra"
:
math
.
MathIntermediateAlgebra
,
"math_num_theory"
:
math
.
MathNumberTheory
,
"math_prealgebra"
:
math
.
MathPrealgebra
,
"math_precalc"
:
math
.
MathPrecalculus
,
# arithmetic
"arithmetic_2da"
:
arithmetic
.
Arithmetic2DPlus
,
"arithmetic_2ds"
:
arithmetic
.
Arithmetic2DMinus
,
...
...
lm_eval/tasks/logiqa.py
View file @
95daf009
...
...
@@ -30,10 +30,27 @@ class LogiQA(MultipleChoiceTask):
return
True
def
_convert_standard
(
self
,
doc
):
def
format_example
(
doc
,
choices
):
"""
Passage: <passage>
Question: <question>
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
Answer:
"""
prompt
=
"Passage: "
+
doc
[
"passage"
]
+
"
\n
"
prompt
+=
"Question: "
+
doc
[
"question"
]
+
"
\n
"
for
choice
,
option
in
zip
(
choices
,
doc
[
"options"
]):
prompt
+=
f
"
{
choice
.
upper
()
}
.
{
option
}
\n
"
prompt
+=
"Answer:"
return
prompt
choices
=
[
'a'
,
'b'
,
'c'
,
'd'
]
return
{
"query"
:
"Passage: "
+
doc
[
"passage"
]
+
"
\n
Question: "
+
doc
[
"question"
]
+
"
\n
Answer:"
,
"query"
:
format_example
(
doc
,
choices
)
,
"choices"
:
doc
[
"options"
],
"gold"
:
[
"a"
,
"b"
,
"c"
,
"d"
]
.
index
(
doc
[
"answerKey"
])
"gold"
:
choices
.
index
(
doc
[
"answerKey"
])
}
def
_load_docs
(
self
,
filename
):
...
...
lm_eval/tasks/math.py
0 → 100644
View file @
95daf009
import
json
import
random
from
lm_eval.utils
import
sh
from
lm_eval.metrics
import
mean
from
lm_eval.base
import
Task
,
rf
from
pathlib
import
Path
import
abc
class
Math
(
Task
):
"""
This dataset is based on the following paper:
https://arxiv.org/abs/2103.03874
"""
DATASET_PATH
=
Path
(
'data/MATH'
)
def
download
(
self
):
if
not
self
.
DATASET_PATH
.
exists
():
sh
(
f
"""
mkdir -p
{
self
.
DATASET_PATH
}
wget https://people.eecs.berkeley.edu/~hendrycks/MATH.tar.gz -P data/
tar -xvf
{
self
.
DATASET_PATH
}
.tar.gz -C data/
rm
{
self
.
DATASET_PATH
}
.tar.gz
"""
)
@
abc
.
abstractmethod
def
get_file_info
(
self
):
"""returns directory name"""
pass
def
has_training_docs
(
self
):
return
True
def
has_validation_docs
(
self
):
return
False
def
has_test_docs
(
self
):
return
True
def
_load_docs
(
self
,
path
):
for
file
in
path
.
iterdir
():
with
open
(
file
)
as
f
:
doc
=
json
.
load
(
f
)
doc
[
"answer"
]
=
self
.
remove_boxed
(
self
.
last_boxed_only_string
(
doc
[
"solution"
]))
yield
doc
def
training_docs
(
self
):
return
self
.
_load_docs
(
self
.
DATASET_PATH
/
"train"
/
self
.
get_file_info
())
def
validation_docs
(
self
):
return
NotImplemented
def
test_docs
(
self
):
return
self
.
_load_docs
(
self
.
DATASET_PATH
/
"test"
/
self
.
get_file_info
())
def
fewshot_description
(
self
):
return
"Given a mathematics problem, determine the answer. Simplify your answer as much as possible."
def
doc_to_text
(
self
,
doc
):
return
"Problem: "
+
doc
[
"problem"
]
+
"
\n
Answer:"
def
doc_to_target
(
self
,
doc
):
return
" "
+
doc
[
"answer"
]
def
construct_requests
(
self
,
doc
,
ctx
):
return
rf
.
greedy_until
(
ctx
,
[
"
\n
"
])
def
process_results
(
self
,
doc
,
results
):
retval
=
0
indices
=
[
pos
for
pos
,
char
in
enumerate
(
results
[
0
])
if
char
==
"$"
]
if
len
(
indices
)
<=
1
:
answer
=
results
[
0
]
else
:
answer
=
results
[
0
][
indices
[
0
]
+
1
:
indices
[
-
1
]]
if
self
.
is_equiv
(
answer
,
self
.
remove_boxed
(
self
.
last_boxed_only_string
(
doc
[
"solution"
]))):
retval
=
1
return
{
"acc"
:
retval
}
def
aggregation
(
self
):
return
{
'acc'
:
mean
}
def
higher_is_better
(
self
):
return
{
'acc'
:
True
}
def
is_equiv
(
self
,
str1
,
str2
,
verbose
=
False
):
if
str1
is
None
and
str2
is
None
:
print
(
"WARNING: Both None"
)
return
True
if
str1
is
None
or
str2
is
None
:
return
False
try
:
ss1
=
self
.
strip_string
(
str1
)
ss2
=
self
.
strip_string
(
str2
)
if
verbose
:
print
(
ss1
,
ss2
)
return
ss1
==
ss2
except
:
return
str1
==
str2
def
remove_boxed
(
self
,
s
):
left
=
"
\\
boxed{"
try
:
assert
s
[:
len
(
left
)]
==
left
assert
s
[
-
1
]
==
"}"
return
s
[
len
(
left
):
-
1
]
except
AssertionError
:
return
None
def
last_boxed_only_string
(
self
,
string
):
idx
=
string
.
rfind
(
"
\\
boxed"
)
if
idx
<
0
:
idx
=
string
.
rfind
(
"
\\
fbox"
)
if
idx
<
0
:
return
None
i
=
idx
right_brace_idx
=
None
num_left_braces_open
=
0
while
i
<
len
(
string
):
if
string
[
i
]
==
"{"
:
num_left_braces_open
+=
1
if
string
[
i
]
==
"}"
:
num_left_braces_open
-=
1
if
num_left_braces_open
==
0
:
right_brace_idx
=
i
break
i
+=
1
if
right_brace_idx
is
None
:
retval
=
None
else
:
retval
=
string
[
idx
:
right_brace_idx
+
1
]
return
retval
def
fix_fracs
(
self
,
string
):
substrs
=
string
.
split
(
"
\\
frac"
)
new_str
=
substrs
[
0
]
if
len
(
substrs
)
>
1
:
substrs
=
substrs
[
1
:]
for
substr
in
substrs
:
new_str
+=
"
\\
frac"
if
substr
[
0
]
==
"{"
:
new_str
+=
substr
else
:
try
:
assert
len
(
substr
)
>=
2
except
AssertionError
:
return
string
a
=
substr
[
0
]
b
=
substr
[
1
]
if
b
!=
"{"
:
if
len
(
substr
)
>
2
:
post_substr
=
substr
[
2
:]
new_str
+=
"{"
+
a
+
"}{"
+
b
+
"}"
+
post_substr
else
:
new_str
+=
"{"
+
a
+
"}{"
+
b
+
"}"
else
:
if
len
(
substr
)
>
2
:
post_substr
=
substr
[
2
:]
new_str
+=
"{"
+
a
+
"}"
+
b
+
post_substr
else
:
new_str
+=
"{"
+
a
+
"}"
+
b
string
=
new_str
return
string
def
fix_a_slash_b
(
self
,
string
):
if
len
(
string
.
split
(
"/"
))
!=
2
:
return
string
a
=
string
.
split
(
"/"
)[
0
]
b
=
string
.
split
(
"/"
)[
1
]
try
:
a
=
int
(
a
)
b
=
int
(
b
)
assert
string
==
"{}/{}"
.
format
(
a
,
b
)
new_string
=
"
\\
frac{"
+
str
(
a
)
+
"}{"
+
str
(
b
)
+
"}"
return
new_string
except
AssertionError
:
return
string
def
remove_right_units
(
self
,
string
):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if
"
\\
text{ "
in
string
:
splits
=
string
.
split
(
"
\\
text{ "
)
assert
len
(
splits
)
==
2
return
splits
[
0
]
else
:
return
string
def
fix_sqrt
(
self
,
string
):
if
"
\\
sqrt"
not
in
string
:
return
string
splits
=
string
.
split
(
"
\\
sqrt"
)
new_string
=
splits
[
0
]
for
split
in
splits
[
1
:]:
if
split
[
0
]
!=
"{"
:
a
=
split
[
0
]
new_substr
=
"
\\
sqrt{"
+
a
+
"}"
+
split
[
1
:]
else
:
new_substr
=
"
\\
sqrt"
+
split
new_string
+=
new_substr
return
new_string
class
NotEqual
:
def
__eq__
(
self
,
other
):
return
False
def
strip_string
(
self
,
string
):
# linebreaks
string
=
string
.
replace
(
"
\n
"
,
""
)
# remove inverse spaces
string
=
string
.
replace
(
"
\\
!"
,
""
)
# replace \\ with \
string
=
string
.
replace
(
"
\\\\
"
,
"
\\
"
)
# replace tfrac and dfrac with frac
string
=
string
.
replace
(
"tfrac"
,
"frac"
)
string
=
string
.
replace
(
"dfrac"
,
"frac"
)
# remove \left and \right
string
=
string
.
replace
(
"
\\
left"
,
""
)
string
=
string
.
replace
(
"
\\
right"
,
""
)
# Remove circ (degrees)
string
=
string
.
replace
(
"^{
\\
circ}"
,
""
)
string
=
string
.
replace
(
"^
\\
circ"
,
""
)
# remove dollar signs
string
=
string
.
replace
(
"
\\
$"
,
""
)
# remove units (on the right)
string
=
self
.
remove_right_units
(
string
)
# remove percentage
string
=
string
.
replace
(
"
\\
%"
,
""
)
string
=
string
.
replace
(
"\%"
,
""
)
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string
=
string
.
replace
(
" ."
,
" 0."
)
string
=
string
.
replace
(
"{."
,
"{0."
)
# if empty, return empty string
if
len
(
string
)
==
0
:
return
string
if
string
[
0
]
==
"."
:
string
=
"0"
+
string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if
len
(
string
.
split
(
"="
))
==
2
:
if
len
(
string
.
split
(
"="
)[
0
])
<=
2
:
string
=
string
.
split
(
"="
)[
1
]
# fix sqrt3 --> sqrt{3}
string
=
self
.
fix_sqrt
(
string
)
# remove spaces
string
=
string
.
replace
(
" "
,
""
)
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string
=
self
.
fix_fracs
(
string
)
# manually change 0.5 --> \frac{1}{2}
if
string
==
"0.5"
:
string
=
"
\\
frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string
=
self
.
fix_a_slash_b
(
string
)
return
string
class
MathAlgebra
(
Math
):
def
get_file_info
(
self
):
return
'algebra'
class
MathCountingAndProbability
(
Math
):
def
get_file_info
(
self
):
return
'counting_and_probability'
class
MathGeometry
(
Math
):
def
get_file_info
(
self
):
return
'geometry'
class
MathIntermediateAlgebra
(
Math
):
def
get_file_info
(
self
):
return
'intermediate_algebra'
class
MathNumberTheory
(
Math
):
def
get_file_info
(
self
):
return
'number_theory'
class
MathPrealgebra
(
Math
):
def
get_file_info
(
self
):
return
'prealgebra'
class
MathPrecalculus
(
Math
):
def
get_file_info
(
self
):
return
'precalculus'
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment