Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
88745155
Commit
88745155
authored
Apr 25, 2022
by
cjlovering
Browse files
Initial integration
parent
6caa0afd
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
480 additions
and
318 deletions
+480
-318
lm_eval/base.py
lm_eval/base.py
+220
-96
lm_eval/evaluator.py
lm_eval/evaluator.py
+70
-35
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+42
-35
lm_eval/tasks/coqa.py
lm_eval/tasks/coqa.py
+39
-60
lm_eval/tasks/drop.py
lm_eval/tasks/drop.py
+42
-30
lm_eval/tasks/race.py
lm_eval/tasks/race.py
+67
-62
No files found.
lm_eval/base.py
View file @
88745155
This diff is collapsed.
Click to expand it.
lm_eval/evaluator.py
View file @
88745155
...
...
@@ -6,21 +6,33 @@ import lm_eval.metrics
import
lm_eval.models
import
lm_eval.tasks
import
lm_eval.base
import
promptsource
import
numpy
as
np
from
promptsource.templates
import
DatasetTemplates
from
lm_eval.utils
import
positional_deprecated
,
run_task_tests
@
positional_deprecated
def
simple_evaluate
(
model
,
model_args
=
None
,
tasks
=
[],
num_fewshot
=
0
,
batch_size
=
None
,
device
=
None
,
no_cache
=
False
,
limit
=
None
,
bootstrap_iters
=
100000
,
description_dict
=
None
,
check_integrity
=
False
):
def
simple_evaluate
(
model
,
model_args
=
None
,
tasks
=
[],
num_fewshot
=
0
,
batch_size
=
None
,
device
=
None
,
no_cache
=
False
,
limit
=
None
,
bootstrap_iters
=
100000
,
description_dict
=
None
,
check_integrity
=
False
,
):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
Name of model or LM object, see lm_eval.models.get_model
:param model_args: Optional[str]
String arguments for each model class, see LM.create_from_arg_string.
String arguments for each model class, see LM.create_from_arg_string.
Ignored if `model` argument is a LM object.
:param tasks: list[Union[str, Task]]
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
...
...
@@ -37,7 +49,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:return
...
...
@@ -49,20 +61,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
assert
tasks
!=
[],
"No tasks specified"
if
isinstance
(
model
,
str
):
if
model_args
is
None
:
model_args
=
""
lm
=
lm_eval
.
models
.
get_model
(
model
).
create_from_arg_string
(
model_args
,
{
'batch_size'
:
batch_size
,
'device'
:
device
})
if
model_args
is
None
:
model_args
=
""
lm
=
lm_eval
.
models
.
get_model
(
model
).
create_from_arg_string
(
model_args
,
{
"batch_size"
:
batch_size
,
"device"
:
device
}
)
else
:
assert
isinstance
(
model
,
lm_eval
.
base
.
LM
)
lm
=
model
if
not
no_cache
:
lm
=
lm_eval
.
base
.
CachingLM
(
lm
,
'lm_cache/'
+
model
+
'_'
+
model_args
.
replace
(
'='
,
'-'
).
replace
(
','
,
'_'
).
replace
(
'/'
,
'-'
)
+
'.db'
lm
,
"lm_cache/"
+
model
+
"_"
+
model_args
.
replace
(
"="
,
"-"
).
replace
(
","
,
"_"
).
replace
(
"/"
,
"-"
)
+
".db"
,
)
task_dict
=
lm_eval
.
tasks
.
get_task_dict
(
tasks
)
task_dict
=
lm_eval
.
tasks
.
get_task_dict
_promptsource
(
tasks
)
if
check_integrity
:
run_task_tests
(
task_list
=
tasks
)
...
...
@@ -72,7 +90,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
task_dict
=
task_dict
,
num_fewshot
=
num_fewshot
,
limit
=
limit
,
description_dict
=
description_dict
description_dict
=
description_dict
,
)
# add info about the model and few shot config
...
...
@@ -85,14 +103,22 @@ def simple_evaluate(model, model_args=None, tasks=[],
"no_cache"
:
no_cache
,
"limit"
:
limit
,
"bootstrap_iters"
:
bootstrap_iters
,
"description_dict"
:
description_dict
"description_dict"
:
description_dict
,
}
return
results
@
positional_deprecated
def
evaluate
(
lm
,
task_dict
,
provide_description
=
None
,
num_fewshot
=
0
,
limit
=
None
,
bootstrap_iters
=
100000
,
description_dict
=
None
):
def
evaluate
(
lm
,
task_dict
,
provide_description
=
None
,
num_fewshot
=
0
,
limit
=
None
,
bootstrap_iters
=
100000
,
description_dict
=
None
,
):
"""Instantiate and evaluate a model on a list of tasks.
:param lm: obj
...
...
@@ -108,7 +134,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
Dictionary of custom task descriptions of the form: `task_name: description`
:return
Dictionary of results
"""
...
...
@@ -118,12 +144,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
assert
not
provide_description
# not implemented.
if
provide_description
is
not
None
:
# nudge people to not specify it at all
print
(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
print
(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
task_dict_items
=
[
(
name
,
task
)
for
name
,
task
in
task_dict
.
items
()
if
(
task
.
has_validation_docs
()
or
task
.
has_test_docs
())
if
(
task
.
has_validation_docs
()
or
task
.
has_test_docs
())
]
results
=
collections
.
defaultdict
(
dict
)
...
...
@@ -158,15 +186,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
rnd
.
seed
(
42
)
rnd
.
shuffle
(
task_docs
)
description
=
description_dict
[
task_name
]
if
description_dict
and
task_name
in
description_dict
else
""
description
=
(
description_dict
[
task_name
]
if
description_dict
and
task_name
in
description_dict
else
""
)
for
doc_id
,
doc
in
enumerate
(
itertools
.
islice
(
task_docs
,
0
,
limit
)):
docs
[(
task_name
,
doc_id
)]
=
doc
ctx
=
task
.
fewshot_context
(
doc
=
doc
,
num_fewshot
=
num_fewshot
,
rnd
=
rnd
,
description
=
description
doc
=
doc
,
num_fewshot
=
num_fewshot
,
rnd
=
rnd
,
description
=
description
)
reqs
=
task
.
construct_requests
(
doc
,
ctx
)
if
not
isinstance
(
reqs
,
(
list
,
tuple
)):
...
...
@@ -189,11 +218,13 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
print
(
"Running"
,
reqtype
,
"requests"
)
resps
=
getattr
(
lm
,
reqtype
)([
req
.
args
for
req
in
reqs
])
resps
=
[
x
if
req
.
index
is
None
else
x
[
req
.
index
]
for
x
,
req
in
zip
(
resps
,
reqs
)]
resps
=
[
x
if
req
.
index
is
None
else
x
[
req
.
index
]
for
x
,
req
in
zip
(
resps
,
reqs
)
]
for
resp
,
(
i
,
task_name
,
doc
,
doc_id
)
in
zip
(
resps
,
requests_origin
[
reqtype
]):
process_res_queue
[(
task_name
,
doc_id
)].
append
((
i
,
resp
))
vals
=
collections
.
defaultdict
(
list
)
# unpack results and sort back in order and return control to Task
...
...
@@ -207,25 +238,29 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
metrics
=
task
.
process_results
(
doc
,
requests
)
for
metric
,
value
in
metrics
.
items
():
vals
[(
task_name
,
metric
)].
append
(
value
)
task_name
,
prompt_name
=
task_name
.
split
(
"+"
)
results
[
task_name
][
"task_name"
]
=
task_name
results
[
task_name
][
"prompt_name"
]
=
prompt_name
# aggregate results
for
(
task_name
,
metric
),
items
in
vals
.
items
():
task
=
task_dict
[
task_name
]
results
[
task_name
][
metric
]
=
task
.
aggregation
()[
metric
](
items
)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
stderr
=
lm_eval
.
metrics
.
stderr_for_metric
(
metric
=
task
.
aggregation
()[
metric
],
bootstrap_iters
=
min
(
bootstrap_iters
,
1000
)
if
metric
in
[
"bleu"
,
"chrf"
,
"ter"
]
else
bootstrap_iters
,
bootstrap_iters
=
min
(
bootstrap_iters
,
1000
)
if
metric
in
[
"bleu"
,
"chrf"
,
"ter"
]
else
bootstrap_iters
,
)
if
stderr
is
not
None
:
results
[
task_name
][
metric
+
"_stderr"
]
=
stderr
(
items
)
return
{
"results"
:
dict
(
results
),
"versions"
:
dict
(
versions
)
}
return
{
"results"
:
dict
(
results
),
"versions"
:
dict
(
versions
)}
def
make_table
(
result_dict
):
...
...
@@ -247,9 +282,9 @@ def make_table(result_dict):
if
m
+
"_stderr"
in
dic
:
se
=
dic
[
m
+
"_stderr"
]
values
.
append
([
k
,
version
,
m
,
'
%.4f
'
%
v
,
'±'
,
'
%.4f
'
%
se
])
values
.
append
([
k
,
version
,
m
,
"
%.4f
"
%
v
,
"±"
,
"
%.4f
"
%
se
])
else
:
values
.
append
([
k
,
version
,
m
,
'
%.4f
'
%
v
,
''
,
''
])
values
.
append
([
k
,
version
,
m
,
"
%.4f
"
%
v
,
""
,
""
])
k
=
""
version
=
""
md_writer
.
value_matrix
=
values
...
...
lm_eval/tasks/__init__.py
View file @
88745155
from
promptsource.templates
import
DatasetTemplates
from
pprint
import
pprint
from
typing
import
List
,
Union
...
...
@@ -58,8 +60,8 @@ from . import storycloze
# 6 total
gpt3_translation_benchmarks
=
{
"wmt14"
:
[
'
en-fr
'
,
'
fr-en
'
],
# French
"wmt16"
:
[
'
en-ro
'
,
'
ro-en
'
,
'
de-en
'
,
'
en-de
'
],
# German, Romanian
"wmt14"
:
[
"
en-fr
"
,
"
fr-en
"
],
# French
"wmt16"
:
[
"
en-ro
"
,
"
ro-en
"
,
"
de-en
"
,
"
en-de
"
],
# German, Romanian
}
...
...
@@ -67,7 +69,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks
=
{
**
gpt3_translation_benchmarks
,
"wmt20"
:
sacrebleu
.
get_langpairs_for_testset
(
"wmt20"
),
"iwslt17"
:
[
'
en-ar
'
,
'
ar-en
'
]
# Arabic
"iwslt17"
:
[
"
en-ar
"
,
"
ar-en
"
],
# Arabic
}
# 319 total
...
...
@@ -91,7 +93,7 @@ TASK_REGISTRY = {
"rte"
:
glue
.
RTE
,
"qnli"
:
glue
.
QNLI
,
"qqp"
:
glue
.
QQP
,
#"stsb": glue.STSB, # not implemented yet
#
"stsb": glue.STSB, # not implemented yet
"sst"
:
glue
.
SST
,
"wnli"
:
glue
.
WNLI
,
# SuperGLUE
...
...
@@ -102,34 +104,26 @@ TASK_REGISTRY = {
"record"
:
superglue
.
ReCoRD
,
"wic"
:
superglue
.
WordsInContext
,
"wsc"
:
superglue
.
SGWinogradSchemaChallenge
,
# Order by benchmark/genre?
"coqa"
:
coqa
.
CoQA
,
"drop"
:
drop
.
DROP
,
"lambada"
:
lambada
.
LAMBADA
,
"lambada_cloze"
:
lambada_cloze
.
LAMBADA_cloze
,
# multilingual lambada
**
lambada_multilingual
.
construct_tasks
(),
"wikitext"
:
wikitext
.
WikiText
,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix
"piqa"
:
piqa
.
PiQA
,
"prost"
:
prost
.
PROST
,
"mc_taco"
:
mc_taco
.
MCTACO
,
# Science related
"pubmedqa"
:
pubmedqa
.
Pubmed_QA
,
"sciq"
:
sciq
.
SciQ
,
"pubmedqa"
:
pubmedqa
.
Pubmed_QA
,
"sciq"
:
sciq
.
SciQ
,
"qasper"
:
qasper
.
QASPER
,
"qa4mre_2011"
:
qa4mre
.
QA4MRE_2011
,
"qa4mre_2012"
:
qa4mre
.
QA4MRE_2012
,
"qa4mre_2013"
:
qa4mre
.
QA4MRE_2013
,
"qa4mre_2011"
:
qa4mre
.
QA4MRE_2011
,
"qa4mre_2012"
:
qa4mre
.
QA4MRE_2012
,
"qa4mre_2013"
:
qa4mre
.
QA4MRE_2013
,
"triviaqa"
:
triviaqa
.
TriviaQA
,
"arc_easy"
:
arc
.
ARCEasy
,
"arc_challenge"
:
arc
.
ARCChallenge
,
...
...
@@ -140,7 +134,7 @@ TASK_REGISTRY = {
"squad2"
:
squad
.
SQuAD2
,
"race"
:
race
.
RACE
,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
"headqa"
:
headqa
.
HeadQAEsDeprecated
,
# for backwards compat - headqa used to default to es
"headqa"
:
headqa
.
HeadQAEsDeprecated
,
# for backwards compat - headqa used to default to es
"headqa_es"
:
headqa
.
HeadQAEs
,
"headqa_en"
:
headqa
.
HeadQAEn
,
"mathqa"
:
mathqa
.
MathQA
,
...
...
@@ -150,21 +144,17 @@ TASK_REGISTRY = {
"anli_r1"
:
anli
.
ANLIRound1
,
"anli_r2"
:
anli
.
ANLIRound2
,
"anli_r3"
:
anli
.
ANLIRound3
,
"ethics_cm"
:
hendrycks_ethics
.
EthicsCM
,
"ethics_deontology"
:
hendrycks_ethics
.
EthicsDeontology
,
"ethics_justice"
:
hendrycks_ethics
.
EthicsJustice
,
"ethics_utilitarianism_original"
:
hendrycks_ethics
.
EthicsUtilitarianismOriginal
,
"ethics_utilitarianism"
:
hendrycks_ethics
.
EthicsUtilitarianism
,
"ethics_virtue"
:
hendrycks_ethics
.
EthicsVirtue
,
"truthfulqa_mc"
:
truthfulqa
.
TruthfulQAMultipleChoice
,
"truthfulqa_gen"
:
truthfulqa
.
TruthfulQAGeneration
,
"truthfulqa_mc"
:
truthfulqa
.
TruthfulQAMultipleChoice
,
"truthfulqa_gen"
:
truthfulqa
.
TruthfulQAGeneration
,
# dialogue
"mutual"
:
mutual
.
MuTual
,
"mutual_plus"
:
mutual
.
MuTualPlus
,
# math
"math_algebra"
:
hendrycks_math
.
MathAlgebra
,
"math_counting_and_prob"
:
hendrycks_math
.
MathCountingAndProbability
,
...
...
@@ -175,7 +165,6 @@ TASK_REGISTRY = {
"math_precalc"
:
hendrycks_math
.
MathPrecalculus
,
"math_asdiv"
:
asdiv
.
Asdiv
,
"gsm8k"
:
gsm8k
.
GradeSchoolMath8K
,
# arithmetic
"arithmetic_2da"
:
arithmetic
.
Arithmetic2DPlus
,
"arithmetic_2ds"
:
arithmetic
.
Arithmetic2DMinus
,
...
...
@@ -189,22 +178,18 @@ TASK_REGISTRY = {
"arithmetic_1dc"
:
arithmetic
.
Arithmetic1DComposite
,
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks)
**
hendrycks_test
.
create_all_tasks
(),
# e.g. wmt14-fr-en
**
translation
.
create_tasks_from_benchmarks
(
gpt3_translation_benchmarks
),
# chef's selection, mostly wmt20
**
translation
.
create_tasks_from_benchmarks
(
selected_translation_benchmarks
),
# Word Scrambling and Manipulation Tasks
"anagrams1"
:
unscramble
.
Anagrams1
,
"anagrams2"
:
unscramble
.
Anagrams2
,
"cycle_letters"
:
unscramble
.
CycleLetters
,
"random_insertion"
:
unscramble
.
RandomInsertion
,
"reversed_words"
:
unscramble
.
ReversedWords
,
# Pile
"pile_arxiv"
:
pile
.
PileArxiv
,
"pile_books3"
:
pile
.
PileBooks3
,
...
...
@@ -228,7 +213,6 @@ TASK_REGISTRY = {
"pile_ubuntu-irc"
:
pile
.
PileUbuntuIrc
,
"pile_wikipedia"
:
pile
.
PileWikipedia
,
"pile_youtubesubtitles"
:
pile
.
PileYoutubeSubtitles
,
# BLiMP
"blimp_adjunct_island"
:
blimp
.
BlimpAdjunctIsland
,
"blimp_anaphor_gender_agreement"
:
blimp
.
BlimpAnaphorGenderAgreement
,
...
...
@@ -297,7 +281,6 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance"
:
blimp
.
BlimpWhVsThatNoGapLongDistance
,
"blimp_wh_vs_that_with_gap"
:
blimp
.
BlimpWhVsThatWithGap
,
"blimp_wh_vs_that_with_gap_long_distance"
:
blimp
.
BlimpWhVsThatWithGapLongDistance
,
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
...
...
@@ -321,19 +304,43 @@ def get_task_name_from_object(task_object):
for
name
,
class_
in
TASK_REGISTRY
.
items
():
if
class_
is
task_object
:
return
name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return
task_object
.
EVAL_HARNESS_NAME
if
hasattr
(
task_object
,
"EVAL_HARNESS_NAME"
)
else
type
(
task_object
).
__name__
return
(
task_object
.
EVAL_HARNESS_NAME
if
hasattr
(
task_object
,
"EVAL_HARNESS_NAME"
)
else
type
(
task_object
).
__name__
)
def
get_task_dict
(
task_name_list
:
List
[
Union
[
str
,
lm_eval
.
base
.
Task
]]):
task_name_dict
=
{
task_name
:
get_task
(
task_name
)()
for
task_name
in
task_name_list
if
isinstance
(
task_name
,
str
)
for
task_name
in
task_name_list
if
isinstance
(
task_name
,
str
)
}
task_name_from_object_dict
=
{
get_task_name_from_object
(
task_object
):
task_object
for
task_object
in
task_name_list
if
not
isinstance
(
task_object
,
str
)
for
task_object
in
task_name_list
if
not
isinstance
(
task_object
,
str
)
}
assert
set
(
task_name_dict
.
keys
()).
isdisjoint
(
set
(
task_name_from_object_dict
.
keys
()))
return
{
**
task_name_dict
,
**
task_name_from_object_dict
}
def
get_task_dict_promptsource
(
task_name_list
:
List
[
str
]):
"""Loads a task instance for each prompt written for that task."""
task_name_dict
=
{}
for
task_name
in
task_name_list
:
assert
isinstance
(
task_name
,
str
)
task_prompts
=
DatasetTemplates
(
task_name
)
for
prompt_name
in
task_prompts
.
all_template_names
:
prompt
=
task_prompts
[
prompt_name
]
# NOTE: We choose a sep that can be easily split.
task_name_dict
[
f
"
{
task_name
}
+
{
prompt_name
}
"
]
=
get_task
(
task_name
)(
prompt
=
prompt
)
return
task_name_dict
lm_eval/tasks/coqa.py
View file @
88745155
...
...
@@ -51,44 +51,22 @@ class CoQA(Task):
def
test_docs
(
self
):
pass
def
doc_to_text
(
self
,
doc
):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai
doc_text
=
doc
[
"story"
]
+
'
\n\n
'
for
(
q
,
a
)
in
zip_longest
(
doc
[
"questions"
][
"input_text"
],
doc
[
"answers"
][
"input_text"
][:
-
1
]):
# omit target answer ai
question
=
f
"Q:
{
q
}
\n\n
"
answer
=
f
"A:
{
a
}
\n\n
"
if
a
is
not
None
else
"A:"
doc_text
+=
question
+
answer
return
doc_text
@
classmethod
def
get_answers
(
cls
,
doc
,
turn_id
):
# Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
answers
=
[]
answer_forturn
=
doc
[
"answers"
][
"input_text"
][
turn_id
-
1
]
answers
.
append
(
answer_forturn
)
additional_answers
=
doc
.
get
(
"additional_answers"
)
if
additional_answers
:
for
key
in
additional_answers
:
additional_answer_for_turn
=
additional_answers
[
key
][
"input_text"
][
turn_id
-
1
]
if
additional_answer_for_turn
.
lower
()
not
in
map
(
str
.
lower
,
answers
):
answers
.
append
(
additional_answer_for_turn
)
return
answers
@
classmethod
def
get_answer_choice
(
self
,
raw_text
):
# Function maps answers to CoQA answer categories
# ~ 1/5 of the CoQA answers are Yes/No
# ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch)
if
raw_text
==
"unknown"
:
return
'0'
if
squad_metrics
.
normalize_answer
(
raw_text
)
==
"yes"
:
return
'1'
if
squad_metrics
.
normalize_answer
(
raw_text
)
==
"no"
:
return
'2'
return
'3'
# Not a yes/no question
# @classmethod
# def get_answers(cls, doc, turn_id):
# # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
# answers = []
# answer_forturn = doc["answers"]["input_text"][turn_id - 1]
# answers.append(answer_forturn)
# additional_answers = doc.get("additional_answers")
# if additional_answers:
# for key in additional_answers:
# additional_answer_for_turn = additional_answers[key]["input_text"][
# turn_id - 1
# ]
# if additional_answer_for_turn.lower() not in map(str.lower, answers):
# answers.append(additional_answer_for_turn)
# return answers
@
staticmethod
def
compute_scores
(
gold_list
,
pred
):
...
...
@@ -98,40 +76,38 @@ class CoQA(Task):
em_sum
=
0.0
if
len
(
gold_list
)
>
1
:
for
i
in
range
(
len
(
gold_list
)):
gold_answers
=
gold_list
[
0
:
i
]
+
gold_list
[
i
+
1
:]
gold_answers
=
gold_list
[
0
:
i
]
+
gold_list
[
i
+
1
:]
# predictions compared against (n) golds and take maximum
em_sum
+=
max
(
squad_metrics
.
compute_exact
(
a
,
pred
)
for
a
in
gold_answers
)
em_sum
+=
max
(
squad_metrics
.
compute_exact
(
a
,
pred
)
for
a
in
gold_answers
)
f1_sum
+=
max
(
squad_metrics
.
compute_f1
(
a
,
pred
)
for
a
in
gold_answers
)
else
:
em_sum
+=
max
(
squad_metrics
.
compute_exact
(
a
,
pred
)
for
a
in
gold_list
)
f1_sum
+=
max
(
squad_metrics
.
compute_f1
(
a
,
pred
)
for
a
in
gold_list
)
return
{
'em'
:
em_sum
/
max
(
1
,
len
(
gold_list
)),
'f1'
:
f1_sum
/
max
(
1
,
len
(
gold_list
))}
def
doc_to_target
(
self
,
doc
,
turnid
=
None
):
# Default to prediction of last turn.
if
turnid
is
None
:
turnid
=
len
(
doc
[
"questions"
][
"input_text"
])
raw_text
=
doc
[
'answers'
][
"input_text"
][
turnid
-
1
]
return
" "
+
raw_text
return
{
"em"
:
em_sum
/
max
(
1
,
len
(
gold_list
)),
"f1"
:
f1_sum
/
max
(
1
,
len
(
gold_list
)),
}
def
construct_requests
(
self
,
doc
,
ctx
):
"""
Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
part of the document for `doc`.
"""
cont_request
=
rf
.
greedy_until
(
ctx
,
[
'
\n
Q:
'
])
cont_request
=
rf
.
greedy_until
(
ctx
,
[
"
\n
Q:
"
])
return
cont_request
def
process_results
(
self
,
doc
,
results
):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
...
...
@@ -139,15 +115,18 @@ class CoQA(Task):
:param results:
The results of the requests created in construct_requests.
"""
turn_id
=
len
(
doc
[
"questions"
][
"input_text"
])
gold_list
=
self
.
get_answers
(
doc
,
turn_id
)
pred
=
results
[
0
].
strip
().
split
(
'
\n
'
)[
0
]
target
=
self
.
doc_to_target
(
doc
).
strip
()
pred
=
results
[
0
].
strip
().
split
(
"
\n
"
)[
0
]
# turn_id = len(doc["questions"]["input_text"])
# gold_list = self.get_answers(doc, turn_id)
scores
=
self
.
compute_scores
(
gold_list
,
pred
)
# TODO: Add HF metrics mapped from promptsource metadata.
scores
=
self
.
compute_scores
([
target
],
pred
)
return
{
"f1"
:
scores
[
'
f1
'
],
"em"
:
scores
[
'
em
'
],
"f1"
:
scores
[
"
f1
"
],
"em"
:
scores
[
"
em
"
],
}
def
higher_is_better
(
self
):
...
...
lm_eval/tasks/drop.py
View file @
88745155
...
...
@@ -70,21 +70,26 @@ class DROP(Task):
@
classmethod
def
get_answers
(
cls
,
qa
):
def
_flatten_validated_answers
(
validated_answers
):
"""
Flattens a dict of lists of validated answers.
"""Flattens a dict of lists of validated answers.
{"number": ['1', '8'], ...}
-> [{"number": ['1'], ...}, {"number": ['8'], ...}]
"""
vas
=
[]
for
i
in
range
(
len
(
validated_answers
[
"number"
])):
vas
.
append
({
"number"
:
validated_answers
[
"number"
][
i
],
"date"
:
validated_answers
[
"date"
][
i
],
"spans"
:
validated_answers
[
"spans"
][
i
],
})
vas
.
append
(
{
"number"
:
validated_answers
[
"number"
][
i
],
"date"
:
validated_answers
[
"date"
][
i
],
"spans"
:
validated_answers
[
"spans"
][
i
],
}
)
return
vas
answers
=
[]
answers_set
=
set
()
candidates
=
[
qa
[
"answer"
]]
+
_flatten_validated_answers
(
qa
[
"validated_answers"
])
candidates
=
[
qa
[
"answer"
]]
+
_flatten_validated_answers
(
qa
[
"validated_answers"
]
)
for
candidate
in
candidates
:
answer
=
cls
.
parse_answer
(
candidate
)
if
answer
in
answers_set
:
...
...
@@ -100,15 +105,17 @@ class DROP(Task):
return
(
str
(
answer
[
"number"
]),)
if
answer
[
"spans"
]
!=
[]:
return
tuple
(
answer
[
"spans"
])
return
(
" "
.
join
([
answer
[
"date"
][
"day"
],
answer
[
"date"
][
"month"
],
answer
[
"date"
][
"year"
]]).
strip
(),)
return
(
" "
.
join
(
[
answer
[
"date"
][
"day"
],
answer
[
"date"
][
"month"
],
answer
[
"date"
][
"year"
]]
).
strip
(),
)
def
doc_to_text
(
self
,
doc
):
return
f
"Passage:
{
doc
[
'passage'
]
}
\n
Question:
{
doc
[
'question'
]
}
\n
Answer:"
#
def doc_to_text(self, doc):
#
return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
def
doc_to_target
(
self
,
doc
):
return
" "
+
", "
.
join
(
doc
[
"answers"
][
0
])
#
def doc_to_target(self, doc):
#
return " " + ", ".join(doc["answers"][0])
def
construct_requests
(
self
,
doc
,
ctx
):
"""Uses RequestFactory to construct Requests and returns an iterable of
...
...
@@ -134,7 +141,13 @@ class DROP(Task):
:param results:
The results of the requests created in construct_requests.
"""
preds
,
golds
=
results
,
doc
[
"answers"
]
pred
=
results
[
0
].
strip
()
target
=
self
.
doc_to_target
(
doc
).
strip
()
preds
=
[
pred
]
golds
=
[
target
]
max_em
=
0
max_f1
=
0
for
gold_answer
in
golds
:
...
...
@@ -142,10 +155,7 @@ class DROP(Task):
if
gold_answer
[
0
].
strip
():
max_em
=
max
(
max_em
,
exact_match
)
max_f1
=
max
(
max_f1
,
f1_score
)
return
{
"em"
:
max_em
,
"f1"
:
max_f1
}
return
{
"em"
:
max_em
,
"f1"
:
max_f1
}
def
get_metrics
(
self
,
predicted
,
gold
):
"""
...
...
@@ -158,7 +168,9 @@ class DROP(Task):
predicted_bags
=
self
.
_answer_to_bags
(
predicted
)
gold_bags
=
self
.
_answer_to_bags
(
gold
)
if
set
(
predicted_bags
[
0
])
==
set
(
gold_bags
[
0
])
and
len
(
predicted_bags
[
0
])
==
len
(
gold_bags
[
0
]):
if
set
(
predicted_bags
[
0
])
==
set
(
gold_bags
[
0
])
and
len
(
predicted_bags
[
0
]
)
==
len
(
gold_bags
[
0
]):
exact_match
=
1.0
else
:
exact_match
=
0.0
...
...
@@ -190,7 +202,9 @@ class DROP(Task):
for
gold_index
,
gold_item
in
enumerate
(
gold
):
for
pred_index
,
pred_item
in
enumerate
(
predicted
):
if
self
.
_match_numbers_if_present
(
gold_item
,
pred_item
):
scores
[
gold_index
,
pred_index
]
=
self
.
_compute_f1
(
pred_item
,
gold_item
)
scores
[
gold_index
,
pred_index
]
=
self
.
_compute_f1
(
pred_item
,
gold_item
)
row_ind
,
col_ind
=
linear_sum_assignment
(
-
scores
)
max_scores
=
np
.
zeros
([
max
(
len
(
gold
),
len
(
predicted
))])
...
...
@@ -256,7 +270,11 @@ class DROP(Task):
def
_normalize
(
self
,
answer
):
tokens
=
[
self
.
_white_space_fix
(
self
.
_remove_articles
(
self
.
_fix_number
(
self
.
_remove_punc
(
token
.
lower
()))))
self
.
_white_space_fix
(
self
.
_remove_articles
(
self
.
_fix_number
(
self
.
_remove_punc
(
token
.
lower
()))
)
)
for
token
in
self
.
_tokenize
(
answer
)
]
tokens
=
[
token
for
token
in
tokens
if
token
.
strip
()]
...
...
@@ -269,10 +287,7 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return
{
"em"
:
mean
,
"f1"
:
mean
}
return
{
"em"
:
mean
,
"f1"
:
mean
}
def
higher_is_better
(
self
):
"""
...
...
@@ -280,7 +295,4 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return
{
"em"
:
True
,
"f1"
:
True
}
return
{
"em"
:
True
,
"f1"
:
True
}
lm_eval/tasks/race.py
View file @
88745155
...
...
@@ -40,7 +40,7 @@ class RACE(Task):
DATASET_NAME
=
"high"
cache
=
{}
letter_to_num
=
{
'A'
:
0
,
'B'
:
1
,
'C'
:
2
,
'D'
:
3
}
letter_to_num
=
{
"A"
:
0
,
"B"
:
1
,
"C"
:
2
,
"D"
:
3
}
def
has_training_docs
(
self
):
return
True
...
...
@@ -59,17 +59,27 @@ class RACE(Task):
# is shown that one document is made per passage.
r
=
collections
.
defaultdict
(
list
)
for
item
in
datasets
.
load_dataset
(
path
=
self
.
DATASET_PATH
,
name
=
self
.
DATASET_NAME
)[
set
]:
r
[
item
[
'article'
]].
append
(
item
)
res
=
list
(
r
.
values
()
>>
each
(
lambda
x
:
{
'article'
:
x
[
0
][
'article'
],
'problems'
:
x
>>
each
(
lambda
y
:
{
'question'
:
y
[
'question'
],
'answer'
:
y
[
'answer'
],
'options'
:
y
[
'options'
],
})
}))
for
item
in
datasets
.
load_dataset
(
path
=
self
.
DATASET_PATH
,
name
=
self
.
DATASET_NAME
)[
set
]:
r
[
item
[
"article"
]].
append
(
item
)
res
=
list
(
r
.
values
()
>>
each
(
lambda
x
:
{
"article"
:
x
[
0
][
"article"
],
"problems"
:
x
>>
each
(
lambda
y
:
{
"question"
:
y
[
"question"
],
"answer"
:
y
[
"answer"
],
"options"
:
y
[
"options"
],
}
),
}
)
)
self
.
cache
[
set
]
=
res
return
res
...
...
@@ -85,49 +95,48 @@ class RACE(Task):
@
classmethod
def
get_answer_option
(
cls
,
problem
):
answer
=
cls
.
letter_to_num
[
problem
[
'
answer
'
]]
return
problem
[
'
options
'
][
answer
]
answer
=
cls
.
letter_to_num
[
problem
[
"
answer
"
]]
return
problem
[
"
options
"
][
answer
]
@
classmethod
def
last_problem
(
cls
,
doc
):
return
doc
[
'problems'
][
-
1
]
def
doc_to_text
(
self
,
doc
):
text
=
'Article: '
+
doc
[
'article'
]
+
'
\n\n
'
for
problem
in
doc
[
'problems'
][:
-
1
]:
if
problem
[
'question'
][
-
6
:]
==
' _ .'
:
text
+=
problem
[
'question'
][
-
5
:]
+
self
.
get_answer_option
(
problem
)
+
'
\n
'
else
:
question
=
'Question: '
+
problem
[
'question'
]
+
'
\n
'
answer
=
'Answer: '
+
self
.
get_answer_option
(
problem
)
+
'
\n
'
text
+=
question
+
answer
text
+=
self
.
last_problem
(
doc
)[
'question'
]
return
text
def
doc_to_target
(
self
,
doc
):
return
" "
+
self
.
get_answer_option
(
self
.
last_problem
(
doc
))
def
construct_requests
(
self
,
doc
,
ctx
):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
problem
=
self
.
last_problem
(
doc
)
ll_choices
=
[
rf
.
loglikelihood
(
ctx
,
" "
+
problem
[
'options'
][
i
])[
0
]
for
i
in
range
(
4
)
]
return
ll_choices
return
doc
[
"problems"
][
-
1
]
# def doc_to_text(self, doc):
# text = 'Article: ' + doc['article'] + '\n\n'
# for problem in doc['problems'][:-1]:
# if problem['question'][-6:] == ' _ .':
# text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
# else:
# question = 'Question: ' + problem['question'] + '\n'
# answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
# text += question + answer
# text += self.last_problem(doc)['question']
# return text
# def doc_to_target(self, doc):
# return " " + self.get_answer_option(self.last_problem(doc))
# def construct_requests(self, doc, ctx):
# """Uses RequestFactory to construct Requests and returns an iterable of
# Requests which will be sent to the LM.
# :param doc:
# The document as returned from training_docs, validation_docs, or test_docs.
# :param ctx: str
# The context string, generated by fewshot_context. This includes the natural
# language description, as well as the few shot examples, and the question
# part of the document for `doc`.
# """
# problem = self.last_problem(doc)
# ll_choices = [
# rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
# ]
# return ll_choices
def
process_results
(
self
,
doc
,
results
):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
...
...
@@ -135,28 +144,24 @@ class RACE(Task):
:param results:
The results of the requests created in construct_requests.
"""
gold
=
self
.
letter_to_num
[
self
.
last_problem
(
doc
)[
'answer'
]]
#
gold
=
self
.
letter_to_num
[
self
.
doc_to_target
(
doc
)]
# gold = self.letter_to_num[self.last_problem(doc)["answer"]]
pred
=
np
.
argmax
(
results
)
return
{
"acc"
:
int
(
pred
==
gold
)
}
return
{
"acc"
:
int
(
pred
==
gold
)}
def
aggregation
(
self
):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return
{
"acc"
:
mean
}
return
{
"acc"
:
mean
}
def
higher_is_better
(
self
):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return
{
"acc"
:
True
}
return
{
"acc"
:
True
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment