gaoqiong / lm-evaluation-harness · Commits

Commit 88745155 ("Initial integration")
Authored Apr 25, 2022 by cjlovering. Parent: 6caa0afd.

Showing 6 changed files with 480 additions and 318 deletions (+480 -318):
    lm_eval/base.py            +220  -96
    lm_eval/evaluator.py        +70  -35
    lm_eval/tasks/__init__.py   +42  -35
    lm_eval/tasks/coqa.py       +39  -60
    lm_eval/tasks/drop.py       +42  -30
    lm_eval/tasks/race.py       +67  -62
lm_eval/base.py (+220 -96)

This file's diff is collapsed on the page and its contents are not shown.
lm_eval/evaluator.py (+70 -35)

@@ -6,21 +6,33 @@ import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
 import lm_eval.base
+import promptsource
 import numpy as np
+from promptsource.templates import DatasetTemplates

 from lm_eval.utils import positional_deprecated, run_task_tests


 @positional_deprecated
-def simple_evaluate(model, model_args=None, tasks=[],
-                    num_fewshot=0, batch_size=None, device=None,
-                    no_cache=False, limit=None, bootstrap_iters=100000,
-                    description_dict=None, check_integrity=False):
+def simple_evaluate(
+    model,
+    model_args=None,
+    tasks=[],
+    num_fewshot=0,
+    batch_size=None,
+    device=None,
+    no_cache=False,
+    limit=None,
+    bootstrap_iters=100000,
+    description_dict=None,
+    check_integrity=False,
+):
     """Instantiate and evaluate a model on a list of tasks.

     :param model: Union[str, LM]
         Name of model or LM object, see lm_eval.models.get_model
     :param model_args: Optional[str]
         String arguments for each model class, see LM.create_from_arg_string.
         Ignored if `model` argument is a LM object.
     :param tasks: list[Union[str, Task]]
         List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.

@@ -37,7 +49,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
     :param bootstrap_iters:
         Number of iterations for bootstrap statistics
     :param description_dict: dict[str, str]
         Dictionary of custom task descriptions of the form: `task_name: description`
     :param check_integrity: bool
         Whether to run the relevant part of the test suite for the tasks
     :return

@@ -49,20 +61,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
     assert tasks != [], "No tasks specified"

     if isinstance(model, str):
         if model_args is None:
             model_args = ""
-        lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
-            'batch_size': batch_size, 'device': device
-        })
+        lm = lm_eval.models.get_model(model).create_from_arg_string(
+            model_args, {"batch_size": batch_size, "device": device}
+        )
     else:
         assert isinstance(model, lm_eval.base.LM)
         lm = model

     if not no_cache:
         lm = lm_eval.base.CachingLM(
-            lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
+            lm,
+            "lm_cache/"
+            + model
+            + "_"
+            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+            + ".db",
         )

-    task_dict = lm_eval.tasks.get_task_dict(tasks)
+    task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks)

     if check_integrity:
         run_task_tests(task_list=tasks)

@@ -72,7 +90,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
         task_dict=task_dict,
         num_fewshot=num_fewshot,
         limit=limit,
-        description_dict=description_dict
+        description_dict=description_dict,
     )

     # add info about the model and few shot config

@@ -85,14 +103,22 @@ def simple_evaluate(model, model_args=None, tasks=[],
         "no_cache": no_cache,
         "limit": limit,
         "bootstrap_iters": bootstrap_iters,
-        "description_dict": description_dict
+        "description_dict": description_dict,
     }

     return results


 @positional_deprecated
-def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None):
+def evaluate(
+    lm,
+    task_dict,
+    provide_description=None,
+    num_fewshot=0,
+    limit=None,
+    bootstrap_iters=100000,
+    description_dict=None,
+):
     """Instantiate and evaluate a model on a list of tasks.

     :param lm: obj

@@ -108,7 +134,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     :param bootstrap_iters:
         Number of iterations for bootstrap statistics
     :param description_dict: dict[str, str]
         Dictionary of custom task descriptions of the form: `task_name: description`
     :return
         Dictionary of results
     """

@@ -118,12 +144,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     assert not provide_description  # not implemented.
     if provide_description is not None:
         # nudge people to not specify it at all
-        print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
+        print(
+            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
+        )

     task_dict_items = [
         (name, task)
         for name, task in task_dict.items()
         if (task.has_validation_docs() or task.has_test_docs())
     ]

     results = collections.defaultdict(dict)

@@ -158,15 +186,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         rnd.seed(42)
         rnd.shuffle(task_docs)

-        description = description_dict[task_name] if description_dict and task_name in description_dict else ""
+        description = (
+            description_dict[task_name]
+            if description_dict and task_name in description_dict
+            else ""
+        )

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
             docs[(task_name, doc_id)] = doc
             ctx = task.fewshot_context(
-                doc=doc,
-                num_fewshot=num_fewshot,
-                rnd=rnd,
-                description=description
+                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
             reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):

@@ -189,11 +218,13 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         print("Running", reqtype, "requests")
         resps = getattr(lm, reqtype)([req.args for req in reqs])
-        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
+        resps = [
+            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
+        ]

         for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
             process_res_queue[(task_name, doc_id)].append((i, resp))

     vals = collections.defaultdict(list)

     # unpack results and sort back in order and return control to Task

@@ -207,25 +238,29 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         metrics = task.process_results(doc, requests)
         for metric, value in metrics.items():
             vals[(task_name, metric)].append(value)
+        task_name, prompt_name = task_name.split("+")
+        results[task_name]["task_name"] = task_name
+        results[task_name]["prompt_name"] = prompt_name

     # aggregate results
     for (task_name, metric), items in vals.items():
         task = task_dict[task_name]
         results[task_name][metric] = task.aggregation()[metric](items)

         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
         stderr = lm_eval.metrics.stderr_for_metric(
             metric=task.aggregation()[metric],
             bootstrap_iters=min(bootstrap_iters, 1000)
             if metric in ["bleu", "chrf", "ter"]
             else bootstrap_iters,
         )
         if stderr is not None:
             results[task_name][metric + "_stderr"] = stderr(items)

     return {
         "results": dict(results),
         "versions": dict(versions)
     }


 def make_table(result_dict):

@@ -247,9 +282,9 @@ def make_table(result_dict):
             if m + "_stderr" in dic:
                 se = dic[m + "_stderr"]
-                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
+                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
             else:
-                values.append([k, version, m, '%.4f' % v, '', ''])
+                values.append([k, version, m, "%.4f" % v, "", ""])
             k = ""
             version = ""
         md_writer.value_matrix = values
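The evaluate() loop above now tags each result with the task and the prompt it came from by splitting the registry key on "+". A minimal sketch of that bookkeeping follows; it is not part of the commit, and the scores and prompt name are invented for illustration:

import collections

results = collections.defaultdict(dict)
vals = {("coqa+best_answer", "f1"): [0.61, 0.47]}  # hypothetical per-document scores

for (task_name, metric), items in vals.items():
    task_name, prompt_name = task_name.split("+")
    results[task_name]["task_name"] = task_name
    results[task_name]["prompt_name"] = prompt_name
    # stand-in for task.aggregation()[metric](items)
    results[task_name][metric] = sum(items) / len(items)

print(dict(results))
# {'coqa': {'task_name': 'coqa', 'prompt_name': 'best_answer', 'f1': 0.54}}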
lm_eval/tasks/__init__.py (+42 -35)

+from promptsource.templates import DatasetTemplates
 from pprint import pprint
 from typing import List, Union

@@ -58,8 +60,8 @@ from . import storycloze
 # 6 total
 gpt3_translation_benchmarks = {
-    "wmt14": ['en-fr', 'fr-en'],  # French
-    "wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'],  # German, Romanian
+    "wmt14": ["en-fr", "fr-en"],  # French
+    "wmt16": ["en-ro", "ro-en", "de-en", "en-de"],  # German, Romanian
 }

@@ -67,7 +69,7 @@ gpt3_translation_benchmarks = {
 selected_translation_benchmarks = {
     **gpt3_translation_benchmarks,
     "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
-    "iwslt17": ['en-ar', 'ar-en']  # Arabic
+    "iwslt17": ["en-ar", "ar-en"],  # Arabic
 }

 # 319 total

@@ -91,7 +93,7 @@ TASK_REGISTRY = {
     "rte": glue.RTE,
     "qnli": glue.QNLI,
     "qqp": glue.QQP,
-    #"stsb": glue.STSB, # not implemented yet
+    # "stsb": glue.STSB, # not implemented yet
     "sst": glue.SST,
     "wnli": glue.WNLI,
     # SuperGLUE

@@ -102,34 +104,26 @@ TASK_REGISTRY = {
     "record": superglue.ReCoRD,
     "wic": superglue.WordsInContext,
     "wsc": superglue.SGWinogradSchemaChallenge,
     # Order by benchmark/genre?
     "coqa": coqa.CoQA,
     "drop": drop.DROP,
     "lambada": lambada.LAMBADA,
     "lambada_cloze": lambada_cloze.LAMBADA_cloze,
     # multilingual lambada
     **lambada_multilingual.construct_tasks(),
     "wikitext": wikitext.WikiText,
     # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
     # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
     "piqa": piqa.PiQA,
     "prost": prost.PROST,
     "mc_taco": mc_taco.MCTACO,
     # Science related
     "pubmedqa": pubmedqa.Pubmed_QA,
     "sciq": sciq.SciQ,
     "qasper": qasper.QASPER,
     "qa4mre_2011": qa4mre.QA4MRE_2011,
     "qa4mre_2012": qa4mre.QA4MRE_2012,
     "qa4mre_2013": qa4mre.QA4MRE_2013,
     "triviaqa": triviaqa.TriviaQA,
     "arc_easy": arc.ARCEasy,
     "arc_challenge": arc.ARCChallenge,

@@ -140,7 +134,7 @@ TASK_REGISTRY = {
     "squad2": squad.SQuAD2,
     "race": race.RACE,
     # "naturalqs": naturalqs.NaturalQs, # not implemented yet
     "headqa": headqa.HeadQAEsDeprecated,  # for backwards compat - headqa used to default to es
     "headqa_es": headqa.HeadQAEs,
     "headqa_en": headqa.HeadQAEn,
     "mathqa": mathqa.MathQA,

@@ -150,21 +144,17 @@ TASK_REGISTRY = {
     "anli_r1": anli.ANLIRound1,
     "anli_r2": anli.ANLIRound2,
     "anli_r3": anli.ANLIRound3,
     "ethics_cm": hendrycks_ethics.EthicsCM,
     "ethics_deontology": hendrycks_ethics.EthicsDeontology,
     "ethics_justice": hendrycks_ethics.EthicsJustice,
     "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
     "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
     "ethics_virtue": hendrycks_ethics.EthicsVirtue,
     "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
     "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
     # dialogue
     "mutual": mutual.MuTual,
     "mutual_plus": mutual.MuTualPlus,
     # math
     "math_algebra": hendrycks_math.MathAlgebra,
     "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,

@@ -175,7 +165,6 @@ TASK_REGISTRY = {
     "math_precalc": hendrycks_math.MathPrecalculus,
     "math_asdiv": asdiv.Asdiv,
     "gsm8k": gsm8k.GradeSchoolMath8K,
     # arithmetic
     "arithmetic_2da": arithmetic.Arithmetic2DPlus,
     "arithmetic_2ds": arithmetic.Arithmetic2DMinus,

@@ -189,22 +178,18 @@ TASK_REGISTRY = {
     "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
     # TODO Perhaps make these groups of tasks
     #   e.g. anli, arithmetic, openai_translations, harness_translations
     # hendrycksTest (57 tasks)
     **hendrycks_test.create_all_tasks(),
     # e.g. wmt14-fr-en
     **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
     # chef's selection, mostly wmt20
     **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
     # Word Scrambling and Manipulation Tasks
     "anagrams1": unscramble.Anagrams1,
     "anagrams2": unscramble.Anagrams2,
     "cycle_letters": unscramble.CycleLetters,
     "random_insertion": unscramble.RandomInsertion,
     "reversed_words": unscramble.ReversedWords,
     # Pile
     "pile_arxiv": pile.PileArxiv,
     "pile_books3": pile.PileBooks3,

@@ -228,7 +213,6 @@ TASK_REGISTRY = {
     "pile_ubuntu-irc": pile.PileUbuntuIrc,
     "pile_wikipedia": pile.PileWikipedia,
     "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
     # BLiMP
     "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
     "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,

@@ -297,7 +281,6 @@ TASK_REGISTRY = {
     "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
     "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
     "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
     # Requires manual download of data.
     # "storycloze_2016": storycloze.StoryCloze2016,
     # "storycloze_2018": storycloze.StoryCloze2018,

@@ -321,19 +304,43 @@ def get_task_name_from_object(task_object):
     for name, class_ in TASK_REGISTRY.items():
         if class_ is task_object:
             return name

     # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
-    return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__
+    return (
+        task_object.EVAL_HARNESS_NAME
+        if hasattr(task_object, "EVAL_HARNESS_NAME")
+        else type(task_object).__name__
+    )


 def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
     task_name_dict = {
         task_name: get_task(task_name)()
         for task_name in task_name_list
         if isinstance(task_name, str)
     }
     task_name_from_object_dict = {
         get_task_name_from_object(task_object): task_object
         for task_object in task_name_list
         if not isinstance(task_object, str)
     }
     assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
     return {**task_name_dict, **task_name_from_object_dict}


+def get_task_dict_promptsource(task_name_list: List[str]):
+    """Loads a task instance for each prompt written for that task."""
+    task_name_dict = {}
+    for task_name in task_name_list:
+        assert isinstance(task_name, str)
+
+        task_prompts = DatasetTemplates(task_name)
+        for prompt_name in task_prompts.all_template_names:
+            prompt = task_prompts[prompt_name]
+            # NOTE: We choose a sep that can be easily split.
+            task_name_dict[f"{task_name}+{prompt_name}"] = get_task(task_name)(prompt=prompt)
+
+    return task_name_dict
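get_task_dict_promptsource() builds one task instance per prompt template rather than one per task name. The following usage sketch is not from the commit; it assumes promptsource is installed, that "coqa" has templates registered under that name, and that the task classes accept the prompt= keyword added elsewhere in this commit (base.py is collapsed above):

from lm_eval import tasks

task_dict = tasks.get_task_dict_promptsource(["coqa"])

# One Task instance per prompt template, keyed as "<dataset>+<template name>".
for key, task in task_dict.items():
    dataset_name, prompt_name = key.split("+")
    print(dataset_name, "|", prompt_name, "|", type(task).__name__)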
lm_eval/tasks/coqa.py (+39 -60)

@@ -51,44 +51,22 @@ class CoQA(Task):
     def test_docs(self):
         pass

-    def doc_to_text(self, doc):
-        # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
-        # and a question qi, the task is to predict the answer ai
-        doc_text = doc["story"] + '\n\n'
-        for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]):  # omit target answer ai
-            question = f"Q: {q}\n\n"
-            answer = f"A: {a}\n\n" if a is not None else "A:"
-            doc_text += question + answer
-        return doc_text
-
-    @classmethod
-    def get_answers(cls, doc, turn_id):
-        # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
-        answers = []
-        answer_forturn = doc["answers"]["input_text"][turn_id - 1]
-        answers.append(answer_forturn)
-
-        additional_answers = doc.get("additional_answers")
-        if additional_answers:
-            for key in additional_answers:
-                additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
-                if additional_answer_for_turn.lower() not in map(str.lower, answers):
-                    answers.append(additional_answer_for_turn)
-        return answers
-
-    @classmethod
-    def get_answer_choice(self, raw_text):
-        # Function maps answers to CoQA answer categories
-        # ~ 1/5 of the CoQA answers are Yes/No
-        # ~ 2/3 of the CoQA answers are span-based
-        # (answers overlap with the passage ignoring punctuation and case mismatch)
-        if raw_text == "unknown":
-            return '0'
-        if squad_metrics.normalize_answer(raw_text) == "yes":
-            return '1'
-        if squad_metrics.normalize_answer(raw_text) == "no":
-            return '2'
-        return '3'  # Not a yes/no question
+    # @classmethod
+    # def get_answers(cls, doc, turn_id):
+    #     # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
+    #     answers = []
+    #     answer_forturn = doc["answers"]["input_text"][turn_id - 1]
+    #     answers.append(answer_forturn)

+    #     additional_answers = doc.get("additional_answers")
+    #     if additional_answers:
+    #         for key in additional_answers:
+    #             additional_answer_for_turn = additional_answers[key]["input_text"][
+    #                 turn_id - 1
+    #             ]
+    #             if additional_answer_for_turn.lower() not in map(str.lower, answers):
+    #                 answers.append(additional_answer_for_turn)
+    #     return answers

     @staticmethod
     def compute_scores(gold_list, pred):

@@ -98,40 +76,38 @@ class CoQA(Task):
         em_sum = 0.0
         if len(gold_list) > 1:
             for i in range(len(gold_list)):
                 gold_answers = gold_list[0:i] + gold_list[i + 1 :]
                 # predictions compared against (n) golds and take maximum
                 em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
                 f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
         else:
             em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
             f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

-        return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
-
-    def doc_to_target(self, doc, turnid=None):
-        # Default to prediction of last turn.
-        if turnid is None:
-            turnid = len(doc["questions"]["input_text"])
-        raw_text = doc['answers']["input_text"][turnid - 1]
-        return " " + raw_text
+        return {
+            "em": em_sum / max(1, len(gold_list)),
+            "f1": f1_sum / max(1, len(gold_list)),
+        }

     def construct_requests(self, doc, ctx):
-        """
-        Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:
             The document as returned from training_docs, validation_docs, or test_docs.
         :param ctx: str
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
         """
-        cont_request = rf.greedy_until(ctx, ['\nQ:'])
+        cont_request = rf.greedy_until(ctx, ["\nQ:"])
         return cont_request

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document

         :param doc:

@@ -139,15 +115,18 @@ class CoQA(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        turn_id = len(doc["questions"]["input_text"])
-        gold_list = self.get_answers(doc, turn_id)
-        pred = results[0].strip().split('\n')[0]
+        target = self.doc_to_target(doc).strip()
+        pred = results[0].strip().split("\n")[0]
+        # turn_id = len(doc["questions"]["input_text"])
+        # gold_list = self.get_answers(doc, turn_id)

-        scores = self.compute_scores(gold_list, pred)
+        # TODO: Add HF metrics mapped from promptsource metadata.
+        scores = self.compute_scores([target], pred)

         return {
-            "f1": scores['f1'],
-            "em": scores['em'],
+            "f1": scores["f1"],
+            "em": scores["em"],
         }

     def higher_is_better(self):
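With this change, CoQA's process_results scores the model's first output line against the single reference returned by doc_to_target. The snippet below is a standalone illustration of that scoring path, not code from the commit; the strings are invented, and the import path for squad_metrics is an assumption about the alias the task module already uses:

from transformers.data.metrics import squad_metrics  # assumed alias used by the task module

target = "in the garden"   # hypothetical doc_to_target(doc).strip()
pred = "the front garden"  # hypothetical model output, first line only

gold_list = [target]
em = max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1 = max(squad_metrics.compute_f1(a, pred) for a in gold_list)

# Mirrors compute_scores([target], pred) for a single gold answer.
print({"em": em / max(1, len(gold_list)), "f1": f1 / max(1, len(gold_list))})
# {'em': 0.0, 'f1': 0.5}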
lm_eval/tasks/drop.py (+42 -30)

@@ -70,21 +70,26 @@ class DROP(Task):
     @classmethod
     def get_answers(cls, qa):
         def _flatten_validated_answers(validated_answers):
-            """
-            Flattens a dict of lists of validated answers.
+            """Flattens a dict of lists of validated answers.
             {"number": ['1', '8'], ...}
             -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
             """
             vas = []
             for i in range(len(validated_answers["number"])):
-                vas.append({
-                    "number": validated_answers["number"][i],
-                    "date": validated_answers["date"][i],
-                    "spans": validated_answers["spans"][i],
-                })
+                vas.append(
+                    {
+                        "number": validated_answers["number"][i],
+                        "date": validated_answers["date"][i],
+                        "spans": validated_answers["spans"][i],
+                    }
+                )
             return vas

         answers = []
         answers_set = set()
-        candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
+        candidates = [qa["answer"]] + _flatten_validated_answers(
+            qa["validated_answers"]
+        )
         for candidate in candidates:
             answer = cls.parse_answer(candidate)
             if answer in answers_set:

@@ -100,15 +105,17 @@ class DROP(Task):
             return (str(answer["number"]),)
         if answer["spans"] != []:
             return tuple(answer["spans"])
-        return (" ".join([answer["date"]["day"],
-                          answer["date"]["month"],
-                          answer["date"]["year"]]).strip(),)
+        return (
+            " ".join(
+                [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
+            ).strip(),
+        )

-    def doc_to_text(self, doc):
-        return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
+    # def doc_to_text(self, doc):
+    #     return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"

-    def doc_to_target(self, doc):
-        return " " + ", ".join(doc["answers"][0])
+    # def doc_to_target(self, doc):
+    #     return " " + ", ".join(doc["answers"][0])

     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of

@@ -134,7 +141,13 @@ class DROP(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        preds, golds = results, doc["answers"]
+        pred = results[0].strip()
+        target = self.doc_to_target(doc).strip()
+        preds = [pred]
+        golds = [target]
+
         max_em = 0
         max_f1 = 0
         for gold_answer in golds:

@@ -142,10 +155,7 @@ class DROP(Task):
             if gold_answer[0].strip():
                 max_em = max(max_em, exact_match)
                 max_f1 = max(max_f1, f1_score)
-        return {
-            "em": max_em,
-            "f1": max_f1
-        }
+        return {"em": max_em, "f1": max_f1}

     def get_metrics(self, predicted, gold):
         """

@@ -158,7 +168,9 @@ class DROP(Task):
         predicted_bags = self._answer_to_bags(predicted)
         gold_bags = self._answer_to_bags(gold)

-        if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
+        if set(predicted_bags[0]) == set(gold_bags[0]) and len(
+            predicted_bags[0]
+        ) == len(gold_bags[0]):
             exact_match = 1.0
         else:
             exact_match = 0.0

@@ -190,7 +202,9 @@ class DROP(Task):
         for gold_index, gold_item in enumerate(gold):
             for pred_index, pred_item in enumerate(predicted):
                 if self._match_numbers_if_present(gold_item, pred_item):
-                    scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item)
+                    scores[gold_index, pred_index] = self._compute_f1(
+                        pred_item, gold_item
+                    )
         row_ind, col_ind = linear_sum_assignment(-scores)

         max_scores = np.zeros([max(len(gold), len(predicted))])

@@ -256,7 +270,11 @@ class DROP(Task):
     def _normalize(self, answer):
         tokens = [
-            self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower()))))
+            self._white_space_fix(
+                self._remove_articles(self._fix_number(self._remove_punc(token.lower())))
+            )
             for token in self._tokenize(answer)
         ]
         tokens = [token for token in tokens if token.strip()]

@@ -269,10 +287,7 @@ class DROP(Task):
         A dictionary where keys are the names of submetrics and values are
         functions that aggregate a list of metrics
         """
-        return {
-            "em": mean,
-            "f1": mean
-        }
+        return {"em": mean, "f1": mean}

     def higher_is_better(self):
         """

@@ -280,7 +295,4 @@ class DROP(Task):
         A dictionary where keys are the names of submetrics and values are
         whether a higher value of the submetric is better
         """
-        return {
-            "em": True,
-            "f1": True
-        }
+        return {"em": True, "f1": True}
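DROP's get_answers still flattens the "validated_answers" columns into one candidate dict per annotator, as reformatted above. Below is a self-contained rendering of that helper for illustration only; the input data is made up, and the behaviour is meant to match the reformatted code:

def flatten_validated_answers(validated_answers):
    """Turn {"number": ["1", "8"], ...} into one dict per index."""
    vas = []
    for i in range(len(validated_answers["number"])):
        vas.append(
            {
                "number": validated_answers["number"][i],
                "date": validated_answers["date"][i],
                "spans": validated_answers["spans"][i],
            }
        )
    return vas


print(
    flatten_validated_answers(
        {"number": ["1", "8"], "date": ["", ""], "spans": [[], ["two touchdowns"]]}
    )
)
# [{'number': '1', 'date': '', 'spans': []},
#  {'number': '8', 'date': '', 'spans': ['two touchdowns']}]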
lm_eval/tasks/race.py (+67 -62)

@@ -40,7 +40,7 @@ class RACE(Task):
     DATASET_NAME = "high"

     cache = {}
-    letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
+    letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}

     def has_training_docs(self):
         return True

@@ -59,17 +59,27 @@ class RACE(Task):
         # is shown that one document is made per passage.

         r = collections.defaultdict(list)
-        for item in datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)[set]:
-            r[item['article']].append(item)
-        res = list(r.values() >> each(lambda x: {
-            'article': x[0]['article'],
-            'problems': x >> each(lambda y: {
-                'question': y['question'],
-                'answer': y['answer'],
-                'options': y['options'],
-            })
-        }))
+        for item in datasets.load_dataset(
+            path=self.DATASET_PATH, name=self.DATASET_NAME
+        )[set]:
+            r[item["article"]].append(item)
+        res = list(
+            r.values()
+            >> each(
+                lambda x: {
+                    "article": x[0]["article"],
+                    "problems": x
+                    >> each(
+                        lambda y: {
+                            "question": y["question"],
+                            "answer": y["answer"],
+                            "options": y["options"],
+                        }
+                    ),
+                }
+            )
+        )

         self.cache[set] = res
         return res

@@ -85,49 +95,48 @@ class RACE(Task):
     @classmethod
     def get_answer_option(cls, problem):
-        answer = cls.letter_to_num[problem['answer']]
-        return problem['options'][answer]
+        answer = cls.letter_to_num[problem["answer"]]
+        return problem["options"][answer]

     @classmethod
     def last_problem(cls, doc):
-        return doc['problems'][-1]
+        return doc["problems"][-1]

-    def doc_to_text(self, doc):
-        text = 'Article: ' + doc['article'] + '\n\n'
-        for problem in doc['problems'][:-1]:
-            if problem['question'][-6:] == ' _ .':
-                text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
-            else:
-                question = 'Question: ' + problem['question'] + '\n'
-                answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
-                text += question + answer
-        text += self.last_problem(doc)['question']
-        return text
-
-    def doc_to_target(self, doc):
-        return " " + self.get_answer_option(self.last_problem(doc))
-
-    def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        problem = self.last_problem(doc)
-        ll_choices = [
-            rf.loglikelihood(ctx, " " + problem['options'][i])[0]
-            for i in range(4)
-        ]
-        return ll_choices
+    # def doc_to_text(self, doc):
+    #     text = 'Article: ' + doc['article'] + '\n\n'
+    #     for problem in doc['problems'][:-1]:
+    #         if problem['question'][-6:] == ' _ .':
+    #             text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
+    #         else:
+    #             question = 'Question: ' + problem['question'] + '\n'
+    #             answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
+    #             text += question + answer
+    #     text += self.last_problem(doc)['question']
+    #     return text

+    # def doc_to_target(self, doc):
+    #     return " " + self.get_answer_option(self.last_problem(doc))

+    # def construct_requests(self, doc, ctx):
+    #     """Uses RequestFactory to construct Requests and returns an iterable of
+    #     Requests which will be sent to the LM.

+    #     :param doc:
+    #         The document as returned from training_docs, validation_docs, or test_docs.
+    #     :param ctx: str
+    #         The context string, generated by fewshot_context. This includes the natural
+    #         language description, as well as the few shot examples, and the question
+    #         part of the document for `doc`.
+    #     """
+    #     problem = self.last_problem(doc)
+    #     ll_choices = [
+    #         rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
+    #     ]
+    #     return ll_choices

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document

         :param doc:

@@ -135,28 +144,24 @@ class RACE(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        gold = self.letter_to_num[self.last_problem(doc)['answer']]
+        # gold = self.letter_to_num[self.last_problem(doc)["answer"]]
+        gold = self.letter_to_num[self.doc_to_target(doc)]
         pred = np.argmax(results)
         return {"acc": int(pred == gold)}

     def aggregation(self):
         """
         :returns: {str: [float] -> float}
         A dictionary where keys are the names of submetrics and values are
         functions that aggregate a list of metrics
         """
         return {"acc": mean}

     def higher_is_better(self):
         """
         :returns: {str: bool}
         A dictionary where keys are the names of submetrics and values are
         whether a higher value of the submetric is better
         """
         return {"acc": True}
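RACE's process_results still picks the answer option with the highest loglikelihood and compares it to the gold index. A toy rendering of that final step follows; it is not code from the commit, and the per-option scores and gold letter are invented:

import numpy as np

letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}

results = [-4.2, -1.3, -7.8, -2.9]  # hypothetical loglikelihoods for options A-D
gold = letter_to_num["B"]           # hypothetical gold answer letter

pred = np.argmax(results)
print({"acc": int(pred == gold)})   # {'acc': 1}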