gaoqiong / lm-evaluation-harness / Commits / bf19a7a1

Unverified commit bf19a7a1, authored Feb 27, 2022 by Leo Gao; committed by GitHub on Feb 27, 2022.

    Merge branch 'master' into researcher2

Parents: 0f283a9c, 663b781b
Changes: 27 files in this commit. This page shows 20 changed files with 707 additions and 271 deletions (+707 / -271).
Files shown on this page:

    README.md: +273 / -252
    lm_eval/evaluator.py: +10 / -2
    lm_eval/models/dummy.py: +7 / -7
    lm_eval/models/gpt3.py: +2 / -0
    lm_eval/tasks/__init__.py: +5 / -0
    lm_eval/tasks/gsm8k.py: +139 / -0
    lm_eval/tasks/hendrycks_math.py: +8 / -8
    lm_eval/tasks/qasper.py: +217 / -0
    lm_eval/utils.py: +33 / -0
    main.py: +3 / -1
    setup.py: +1 / -1
    tests/testdata/gsm8k-v0-greedy_until: +1 / -0
    tests/testdata/gsm8k-v0-res.json: +1 / -0
    tests/testdata/math_algebra-v1-greedy_until: +1 / -0
    tests/testdata/math_algebra-v1-res.json: +1 / -0
    tests/testdata/math_counting_and_prob-v1-greedy_until: +1 / -0
    tests/testdata/math_counting_and_prob-v1-res.json: +1 / -0
    tests/testdata/math_geometry-v1-greedy_until: +1 / -0
    tests/testdata/math_geometry-v1-res.json: +1 / -0
    tests/testdata/math_intermediate_algebra-v1-greedy_until: +1 / -0
README.md

    This diff is collapsed.
lm_eval/evaluator.py

 import collections
 import itertools
+import pathlib
 import random
 import lm_eval.metrics
 import lm_eval.models
@@ -7,13 +8,15 @@ import lm_eval.tasks
 import lm_eval.base
 import lm_eval.decontamination
 import numpy as np
-from lm_eval.utils import positional_deprecated
+from lm_eval.utils import positional_deprecated, run_task_tests

 @positional_deprecated
 def simple_evaluate(model, model_args=None, tasks=[],
                     num_fewshot=0, batch_size=None, device=None,
                     no_cache=False, limit=None, bootstrap_iters=100000,
-                    description_dict=None, decontamination_ngrams_path=None):
+                    description_dict=None, check_integrity=False, decontamination_ngrams_path=None):
     """Instantiate and evaluate a model on a list of tasks.

     :param model: Union[str, LM]
...
@@ -37,6 +40,8 @@ def simple_evaluate(model, model_args=None, tasks=[],
         Number of iterations for bootstrap statistics
     :param description_dict: dict[str, str]
         Dictionary of custom task descriptions of the form: `task_name: description`
+    :param check_integrity: bool
+        Whether to run the relevant part of the test suite for the tasks
     :return
         Dictionary of results
     """
...
@@ -61,6 +66,9 @@ def simple_evaluate(model, model_args=None, tasks=[],
     task_dict = lm_eval.tasks.get_task_dict(tasks)

+    if check_integrity:
+        run_task_tests(task_list=tasks)
+
     results = evaluate(
         lm=lm,
         task_dict=task_dict,
...
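For orientation, this is how the new check_integrity flag would be exercised through the Python API. A minimal sketch, assuming the harness is installed at this revision; the model name and task list are illustrative and not part of the commit:

    import lm_eval.evaluator

    # check_integrity=True takes the new code path above: run_task_tests() is called
    # on the task list before evaluation starts, so evaluation only proceeds if the
    # matching tests in tests/test_version_stable.py pass.
    results = lm_eval.evaluator.simple_evaluate(
        model="gpt2",            # illustrative; any registered model name works
        tasks=["gsm8k"],
        num_fewshot=0,
        check_integrity=True,
    )
    print(results["results"])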
lm_eval/models/dummy.py

@@ -7,30 +7,30 @@ class DummyLM(LM):
         pass

     @classmethod
-    def create_from_arg_string(cls, arg_string):
+    def create_from_arg_string(cls, arg_string, additional_config=None):
         return cls()

     def loglikelihood(self, requests):
         res = []
         for _ in requests:
             res.append((-random.random(), False))
         return res

     def greedy_until(self, requests):
         res = []
         for ctx, _ in requests:
             res.append("lol")
-            assert ctx.strip() != ''
+            assert ctx.strip() != ""
         return res

     def loglikelihood_rolling(self, requests):
         res = []
         for _ in requests:
             res.append(-random.random())
         return res
\ No newline at end of file
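A small usage sketch (not from the commit) showing what the widened classmethod signature accepts; DummyLM ignores both arguments, so the values are illustrative:

    from lm_eval.models.dummy import DummyLM

    # additional_config brings DummyLM in line with the other model wrappers.
    lm = DummyLM.create_from_arg_string("", additional_config={"batch_size": 8})
    # greedy_until requires a non-empty context and always returns "lol" per request.
    print(lm.greedy_until([("Some non-empty context", ["\n"])]))  # -> ['lol']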
lm_eval/models/gpt3.py

@@ -46,6 +46,8 @@ def oa_completion(**kwargs):
     try:
         return openai.Completion.create(**kwargs)
     except openai.error.OpenAIError:
+        import traceback
+        traceback.print_exc()
         time.sleep(backoff_time)
         backoff_time *= 1.5
...
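The hunk only shows the body of the retry handler; the surrounding oa_completion logic is an exponential backoff around the OpenAI call. A self-contained sketch of that pattern with the newly added traceback printing (function name and backoff constants are illustrative, not the file's exact code):

    import time
    import traceback

    def retry_with_backoff(call, initial_backoff=3.0, factor=1.5):
        """Retry `call` until it succeeds, sleeping longer after every failure."""
        backoff_time = initial_backoff
        while True:
            try:
                return call()
            except Exception:
                # Added in this merge: surface the underlying error before sleeping,
                # instead of retrying silently.
                traceback.print_exc()
                time.sleep(backoff_time)
                backoff_time *= factor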
lm_eval/tasks/__init__.py

@@ -29,6 +29,7 @@ from . import triviaqa
 from . import pubmedqa
 from . import sciq
 from . import webqs
+from . import qasper
 from . import qa4mre
 from . import translation
 from . import headqa
...
@@ -48,6 +49,7 @@ from . import mutual
 from . import truthfulqa
 from . import blimp
 from . import asdiv
+from . import gsm8k

 ########################################
 # Translation tasks
...
@@ -121,6 +123,8 @@ TASK_REGISTRY = {
     "pubmedqa": pubmedqa.Pubmed_QA,
     "sciq": sciq.SciQ,
+    "qasper": qasper.QASPER,
     "qa4mre_2011": qa4mre.QA4MRE_2011,
     "qa4mre_2012": qa4mre.QA4MRE_2012,
     "qa4mre_2013": qa4mre.QA4MRE_2013,
...
@@ -170,6 +174,7 @@ TASK_REGISTRY = {
     "math_prealgebra": hendrycks_math.MathPrealgebra,
     "math_precalc": hendrycks_math.MathPrecalculus,
     "math_asdiv": asdiv.Asdiv,
+    "gsm8k": gsm8k.GradeSchoolMath8K,
     # arithmetic
     "arithmetic_2da": arithmetic.Arithmetic2DPlus,
...
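With the two registrations above, the new tasks resolve through the same registry lookup the evaluator already uses. A hedged sketch (assumes the harness is importable at this revision; note that building the tasks triggers their dataset downloads):

    import lm_eval.tasks

    # Both names now appear in TASK_REGISTRY, so get_task_dict can build them
    # exactly like any pre-existing task.
    task_dict = lm_eval.tasks.get_task_dict(["gsm8k", "qasper"])
    print(sorted(task_dict.keys()))  # expected: ['gsm8k', 'qasper']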
lm_eval/tasks/gsm8k.py (new file, 0 → 100644)

"""
"Training Verifiers to Solve Math Word Problems"
https://arxiv.org/abs/2110.14168

@misc{cobbe2021training,
      title={Training Verifiers to Solve Math Word Problems},
      author={Karl Cobbe and Vineet Kosaraju and Mohammad Bavarian and Jacob Hilton and Reiichiro Nakano and Christopher Hesse and John Schulman},
      year={2021},
      eprint={2110.14168},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

NOTE: See the official implementation of the task:
    https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
for how to make use of the dataset's calculator annotations in your language
model's sample/generation function.
"""
import json
import re
from best_download import download_file
from pathlib import Path
from lm_eval.base import Task, rf
from lm_eval.metrics import mean


ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
INVALID_ANS = "[invalid]"


class GradeSchoolMath8K(Task):
    VERSION = 0
    DATASET_PATH = Path('data/gsm8k')

    def download(self):
        if self.DATASET_PATH.exists():
            return
        Path.mkdir(self.DATASET_PATH, parents=True)
        base_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data"
        splits = [
            {"name": "train", "checksum": "17f347dc51477c50d4efb83959dbb7c56297aba886e5544ee2aaed3024813465"},
            {"name": "test", "checksum": "3730d312f6e3440559ace48831e51066acaca737f6eabec99bccb9e4b3c39d14"},
        ]
        for split in splits:
            file = self.DATASET_PATH / f"{split['name']}.jsonl"
            download_file(f"{base_url}/{split['name']}.jsonl", str(file), split["checksum"])

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def _load_docs(self, file):
        return (json.loads(line) for line in open(file).read().splitlines())

    def training_docs(self):
        return self._load_docs(self.DATASET_PATH / "train.jsonl")

    def validation_docs(self):
        raise NotImplementedError

    def test_docs(self):
        return self._load_docs(self.DATASET_PATH / "test.jsonl")

    def doc_to_text(self, doc):
        return "Question: " + doc['question'] + '\nAnswer:'

    def doc_to_target(self, doc):
        return " " + doc['answer']

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # NOTE: The paper implements "verifiers" that assign a score to multiple
        # solutions and output the highest ranked solution.
        completion = rf.greedy_until(ctx, ['\n'])
        return completion

    def _extract_answer(self, completion):
        match = ANS_RE.search(completion)
        if match:
            match_str = match.group(1).strip()
            match_str = match_str.replace(",", "")
            return match_str
        else:
            return INVALID_ANS

    def _is_correct(self, completion, answer):
        gold = self._extract_answer(answer)
        assert gold != INVALID_ANS, "No ground truth answer found in the document."
        return self._extract_answer(completion) == gold

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        completion = results[0]
        answer = doc["answer"]
        return {
            "acc": self._is_correct(completion, answer)
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "acc": mean
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "acc": True
        }
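To make the answer-matching convention concrete: GSM8K reference solutions end with a "#### <number>" line, and the task compares the number extracted from the model completion against the number extracted from the gold answer. A self-contained check of the same ANS_RE pattern on a made-up solution string:

    import re

    ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
    INVALID_ANS = "[invalid]"

    completion = "It costs 2 * 1,500 = 3,000 dollars in total.\n#### 3,000"
    match = ANS_RE.search(completion)
    extracted = match.group(1).strip().replace(",", "") if match else INVALID_ANS
    print(extracted)  # -> "3000" (commas are stripped, just as in _extract_answer)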
lm_eval/tasks/hendrycks_math.py

@@ -18,7 +18,7 @@ class Math(Task):
     def download(self):
         if not (self.DATASET_PATH / 'test').exists() or not (self.DATASET_PATH / 'done').exists():
             sh(f"mkdir -p {self.DATASET_PATH}")
-            download_file("https://people.eecs.berkeley.edu/~hendrycks/MATH.tar", local_file=f"{self.DATASET_PATH}.tar", expected_checksum="01256fd7cd5430596fdf07e6e6a5827111b5235b7ffed679c662a12f898932da")
+            download_file("https://people.eecs.berkeley.edu/~hendrycks/MATH.tar", local_file=f"{self.DATASET_PATH}.tar", expected_checksum="0fbe4fad0df66942db6c221cdcc95b298cc7f4595a2f0f518360cce84e90d9ac")
             sh(f"""
             tar -xf {self.DATASET_PATH}.tar -C data/ && touch {self.DATASET_PATH / 'done'}
             rm {self.DATASET_PATH}.tar
...
@@ -291,42 +291,42 @@ class Math(Task):
 class MathAlgebra(Math):
-    VERSION = 0
+    VERSION = 1

     def get_file_info(self):
         return 'algebra'

 class MathCountingAndProbability(Math):
-    VERSION = 0
+    VERSION = 1

     def get_file_info(self):
         return 'counting_and_probability'

 class MathGeometry(Math):
-    VERSION = 0
+    VERSION = 1

     def get_file_info(self):
         return 'geometry'

 class MathIntermediateAlgebra(Math):
-    VERSION = 0
+    VERSION = 1

     def get_file_info(self):
         return 'intermediate_algebra'

 class MathNumberTheory(Math):
-    VERSION = 0
+    VERSION = 1

     def get_file_info(self):
         return 'number_theory'

 class MathPrealgebra(Math):
-    VERSION = 0
+    VERSION = 1

     def get_file_info(self):
         return 'prealgebra'

 class MathPrecalculus(Math):
-    VERSION = 0
+    VERSION = 1

     def get_file_info(self):
         return 'precalculus'
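Besides the per-subject VERSION bump, which pairs with the new *-v1 test data files below, the only functional change here is the expected checksum pinned for MATH.tar. The 64-hex-digit value suggests a SHA-256 digest, which best_download would verify after downloading; as a rough, hypothetical illustration of what that check amounts to (the local file path is illustrative):

    import hashlib

    def sha256sum(path, chunk_size=1 << 20):
        """Stream a file and return its hex SHA-256 digest."""
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

    expected = "0fbe4fad0df66942db6c221cdcc95b298cc7f4595a2f0f518360cce84e90d9ac"
    assert sha256sum("data/MATH.tar") == expected, "archive does not match the pinned checksum"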
lm_eval/tasks/qasper.py (new file, 0 → 100644)

"""
A Dataset of Information-Seeking Questions and Answers Anchored in Research Papers
https://arxiv.org/abs/2105.03011

@article{DBLP:journals/corr/abs-2105-03011,
  author     = {Pradeep Dasigi and
                Kyle Lo and
                Iz Beltagy and
                Arman Cohan and
                Noah A. Smith and
                Matt Gardner},
  title      = {A Dataset of Information-Seeking Questions and Answers Anchored in
                Research Papers},
  journal    = {CoRR},
  volume     = {abs/2105.03011},
  year       = {2021},
  url        = {https://arxiv.org/abs/2105.03011},
  eprinttype = {arXiv},
  eprint     = {2105.03011},
  timestamp  = {Fri, 14 May 2021 12:13:30 +0200},
  biburl     = {https://dblp.org/rec/journals/corr/abs-2105-03011.bib},
  bibsource  = {dblp computer science bibliography, https://dblp.org}
}
"""
from collections import Counter
from math import exp
import random
import re
import string

from lm_eval.base import rf
from lm_eval.metrics import f1_score, mean
from .common import HFTask


def normalize_answer(s):
    """
    Taken from the official evaluation script for v1.1 of the SQuAD dataset.
    Lower text and remove punctuation, articles and extra whitespace.
    """

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def categorise_answer(answer_blob):
    if answer_blob["unanswerable"]:
        answer = "unanswerable"
        answer_type = "unanswerable"
        return answer, answer_type
    elif answer_blob["yes_no"]:
        answer = "yes"
        answer_type = "bool"
        return answer, answer_type
    elif answer_blob["free_form_answer"]:
        answer = answer_blob["free_form_answer"]
        answer_type = "free form answer"
        return answer, answer_type
    elif answer_blob["extractive_spans"]:
        answer = answer_blob["extractive_spans"]
        answer_type = "extractive_spans"
        return answer, answer_type
    elif answer_blob["yes_no"] is False:
        answer = "no"
        answer_type = "bool"
        return answer, answer_type


def token_f1_score(prediction, ground_truth):
    """
    Taken from the official evaluation script for v1.1 of the SQuAD dataset.
    """
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


class QASPER(HFTask):
    VERSION = 0
    DATASET_PATH = "qasper"
    DATASET_NAME = None

    def doc_to_text(self, doc):
        return (
            "TITLE: " + doc["title"] + "\n"
            + "ABSTRACT: " + doc["abstract"] + "\n\n"
            + "Q: " + doc["question"] + "\n\n"
            + "A:"
        )

    def doc_to_target(self, doc):
        answer = doc["answer"]
        if isinstance(answer, list):
            answer = ", ".join(answer)
        return " " + answer

    def training_docs(self):
        for doc in self.data["train"]:
            yield from self.process_doc(doc)

    def validation_docs(self):
        for doc in self.data["train"]:
            yield from self.process_doc(doc)

    def process_doc(self, doc):
        """Given a `doc`, flatten it out so that each JSON blob
        contains exactly one question and one answer. Logic taken from
        the reference implementation available at
        https://github.com/allenai/qasper-led-baseline/blob/main/scripts/evaluator.py
        """
        obs_list = []
        for question, answer_list in zip(doc["qas"]["question"], doc["qas"]["answers"]):
            for answer_blob in answer_list["answer"]:
                answer, answer_type = categorise_answer(answer_blob)
                obs_list.append(
                    {
                        "title": doc["title"],
                        "abstract": doc["abstract"],
                        "question": question,
                        "answer": answer,
                        "answer_type": answer_type,
                    }
                )
        return obs_list

    def process_results(self, doc, results):
        # TODO: Calculate a score for extractive spans once a request type for generating
        # extractive spans is available
        if not results:
            return {}
        elif len(results) == 1:
            [res] = results
        elif len(results) == 2:
            [ll_yes, ll_no] = results

        # TODO: Handle unanswerability first
        # unanswerable_gold = doc["answer_type"] == "unanswerable"
        # unanswerable_pred = exp(logprob_unanswerable)
        # res_dict["f1_unanswerable"] = (unanswerable_gold, unanswerable_pred)

        res_dict = {}

        # Handle yes/no questions
        if doc["answer_type"] == "bool":
            gold = 1 if doc["answer"] == "yes" else 0
            pred = ll_yes > ll_no
            res_dict["f1_yesno"] = (gold, pred)

        # Handle completions
        if doc["answer_type"] == "free form answer":
            res_dict["f1_abstractive"] = token_f1_score(res, doc["answer"])

        # TODO: Handle extraction
        # if doc["answer_type"] == "extractive_spans":
        #     res_dict["f1_extractive"] = 0
        return res_dict

    def aggregation(self):
        return {
            "f1_yesno": f1_score,
            "f1_abstractive": mean,
        }

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
        if doc["answer_type"] in ("free form answer"):
            return [rf.greedy_until(ctx, ["\n"])]
        elif doc["answer_type"] in ("bool"):
            ll_yes, _ = rf.loglikelihood(ctx, " yes")
            ll_no, _ = rf.loglikelihood(ctx, " no")
            return [ll_yes, ll_no]
        else:
            return []

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "f1_yesno": True,
            "f1_abstractive": True,
        }
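For the abstractive branch, scoring is plain token-overlap F1 after SQuAD-style normalization. A self-contained restatement of that computation on a made-up prediction/reference pair (it mirrors token_f1_score above rather than importing it):

    from collections import Counter

    def token_f1(prediction_tokens, ground_truth_tokens):
        """Token-overlap F1, as in the SQuAD v1.1 evaluation script."""
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0.0
        precision = num_same / len(prediction_tokens)
        recall = num_same / len(ground_truth_tokens)
        return 2 * precision * recall / (precision + recall)

    pred = "a transformer encoder".split()
    gold = "transformer encoder with pretraining".split()
    print(round(token_f1(pred, gold), 3))  # 2 shared tokens -> P=2/3, R=2/4, F1 ~ 0.571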
lm_eval/utils.py

 import os
+import pathlib
 import re
 import collections
 import functools
 import inspect
+import sys
+import pytest
+from typing import List

 class ExitCodeError(Exception):
...
@@ -155,3 +159,32 @@ def positional_deprecated(fn):
                 "lm-evaluation-harness!"
             )
         return fn(*args, **kwargs)
     return _wrapper
+
+
+@positional_deprecated
+def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
+    """
+    Search upward in the directory tree to a maximum of three layers
+    to find and return the package root (containing the 'tests' folder)
+    """
+    cur_path = start_path.resolve()
+    max_layers = 3
+    for _ in range(max_layers):
+        if (cur_path / 'tests' / 'test_version_stable.py').exists():
+            return cur_path
+        else:
+            cur_path = cur_path.parent.resolve()
+    raise FileNotFoundError(f"Unable to find package root within {max_layers} upwards" + \
+        f"of {start_path}")
+
+
+@positional_deprecated
+def run_task_tests(task_list: List[str]):
+    """
+    Find the package root and run the tests for the given tasks
+    """
+    package_root = find_test_root(start_path=pathlib.Path(__file__))
+    task_string = ' or '.join(task_list)
+    args = [f'{package_root}/tests/test_version_stable.py', f'--rootdir={package_root}', '-k', f'{task_string}']
+    sys.path.append(str(package_root))
+    pytest_return_val = pytest.main(args)
+    if pytest_return_val:
+        raise ValueError(f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}")
\ No newline at end of file
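run_task_tests is the helper that simple_evaluate's new check_integrity path calls. A short usage sketch, assuming a harness checkout that contains tests/test_version_stable.py as the helper expects:

    from lm_eval.utils import run_task_tests

    # Roughly equivalent to running:
    #   pytest <package_root>/tests/test_version_stable.py --rootdir=<package_root> -k "gsm8k or math_algebra"
    # Raises ValueError if pytest reports a non-zero exit code.
    run_task_tests(task_list=["gsm8k", "math_algebra"])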
main.py

@@ -38,7 +38,8 @@ def parse_args():
     parser.add_argument('--decontaminate', action="store_true")
     parser.add_argument('--decontaminate_ngrams_path', default=None)
     parser.add_argument('--decontaminate_ngrams_n_size', type=int, default=None)
     parser.add_argument('--description_dict_path', default=None)
+    parser.add_argument('--check_integrity', action="store_true")

     return parser.parse_args()
...
@@ -98,6 +99,7 @@ def main():
         decontaminate=args.decontaminate,
         decontaminate_ngrams_path=args.decontaminate_ngrams_path,
-        decontaminate_ngrams_n_size=args.decontaminate_ngrams_n_size
+        decontaminate_ngrams_n_size=args.decontaminate_ngrams_n_size,
+        check_integrity=args.check_integrity
     )
     dumped = json.dumps(results, indent=2)
...
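On the CLI side, --check_integrity is a plain store_true flag that main() forwards to simple_evaluate. A hypothetical invocation would look like "python main.py --model gpt2 --tasks gsm8k --check_integrity" (model and task names are illustrative). The flag behaves as in this minimal argparse sketch:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--check_integrity', action="store_true")

    print(parser.parse_args(['--check_integrity']).check_integrity)  # True
    print(parser.parse_args([]).check_integrity)                     # False (default)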
setup.py

@@ -21,7 +21,7 @@ setuptools.setup(
     python_requires='>=3.6',
     install_requires=[
         "black",
-        "best_download>=0.0.6",
+        "best_download==0.0.9",
         "datasets==1.15.1",
         "click>=7.1",
         "scikit-learn>=0.24.1",
...
tests/testdata/gsm8k-v0-greedy_until (new file, 0 → 100644)

    e7292dbdd7fd8419ba954f2e0701e04c8d0e8842fe053dbf2fe47d926630e35e
    \ No newline at end of file

tests/testdata/gsm8k-v0-res.json (new file, 0 → 100644)

    {"results": {"gsm8k": {"acc": 0.0, "acc_stderr": 0.0}}, "versions": {"gsm8k": 0}}
    \ No newline at end of file

tests/testdata/math_algebra-v1-greedy_until (new file, 0 → 100644)

    f19182ce697a2c095d9e5b56ee6659dc38c93994b69ca75d7c3d3f5fd87572b4
    \ No newline at end of file

tests/testdata/math_algebra-v1-res.json (new file, 0 → 100644)

    {"results": {"math_algebra": {"acc": 0.0, "acc_stderr": 0.0}}, "versions": {"math_algebra": 1}}
    \ No newline at end of file

tests/testdata/math_counting_and_prob-v1-greedy_until (new file, 0 → 100644)

    2aa9ae43ee9dbb2457525247d7b65358632c5eaa9cbfc40cf95a4f17f5d942ad
    \ No newline at end of file

tests/testdata/math_counting_and_prob-v1-res.json (new file, 0 → 100644)

    {"results": {"math_counting_and_prob": {"acc": 0.0, "acc_stderr": 0.0}}, "versions": {"math_counting_and_prob": 1}}
    \ No newline at end of file

tests/testdata/math_geometry-v1-greedy_until (new file, 0 → 100644)

    46bc4cb219b6903397da782699a684bdbb982c0c954ff82e6beeed5c84878f42
    \ No newline at end of file

tests/testdata/math_geometry-v1-res.json (new file, 0 → 100644)

    {"results": {"math_geometry": {"acc": 0.0, "acc_stderr": 0.0}}, "versions": {"math_geometry": 1}}
    \ No newline at end of file

tests/testdata/math_intermediate_algebra-v1-greedy_until (new file, 0 → 100644)

    d53c699de272d517ed7ad783b4e692302be9f9f97a8d4ac7a6541e538a7cabe0
    \ No newline at end of file