gaoqiong / lm-evaluation-harness / Commits / d88a566c

Unverified commit d88a566c, authored Aug 01, 2023 by Lintang Sutawika, committed by GitHub on Aug 01, 2023.

Merge pull request #612 from EleutherAI/benchmark-scripts

[Refactor] Benchmark scripts

Parents: 4168c05f, 29f12dd9

Changes: 30 files in total. This page shows 20 changed files with 377 additions and 80 deletions (+377 −80); the remaining files are on the second page of the diff.
.gitignore                                              +1    -0
lm_eval/api/registry.py                                 +9    -3
lm_eval/api/task.py                                     +19   -4
lm_eval/evaluator.py                                    +61   -8
lm_eval/prompts/__init__.py                             +15   -0
lm_eval/tasks/__init__.py                               +112  -17
lm_eval/tasks/benchmarks/pythia.yaml                    +12   -0
lm_eval/tasks/benchmarks/t0_eval.yaml                   +91   -0
lm_eval/tasks/super_glue/boolq/default.yaml             +1    -1
lm_eval/tasks/super_glue/cb/t5-prompt.yaml              +5    -1
lm_eval/tasks/super_glue/copa/t5-prompt.yaml            +3    -1
lm_eval/tasks/super_glue/multirc/promptsource-00.yaml   +0    -14
lm_eval/tasks/super_glue/multirc/promptsource-01.yaml   +0    -5
lm_eval/tasks/super_glue/multirc/promptsource-02.yaml   +0    -5
lm_eval/tasks/super_glue/record/default.yaml            +6    -2
lm_eval/tasks/super_glue/record/promptsource-00.yaml    +0    -14
lm_eval/tasks/super_glue/record/t5-prompt.yaml          +1    -0
lm_eval/tasks/super_glue/record/util.py                 +28   -0
lm_eval/tasks/super_glue/rte/default.yaml               +13   -0
lm_eval/tasks/super_glue/rte/promptsource-01.yaml       +0    -5
.gitignore

 env
 *.pyc
 output/
 data/
 lm_cache
+.idea
lm_eval/api/registry.py

 import os

 import evaluate

 from lm_eval.api.model import LM
+from lm_eval.logger import eval_logger

 MODEL_REGISTRY = {}

@@ -130,7 +131,7 @@ searching in HF Evaluate library..."
         metric_object = evaluate.load(name)
         return metric_object.compute
     except Exception:
-        raise Warning(
+        eval_logger.error(
             "{} not found in the evaluate library!".format(name),
             "Please check https://huggingface.co/evaluate-metric",
         )

@@ -153,7 +154,7 @@ def get_aggregation(name):
     try:
         return AGGREGATION_REGISTRY[name]
     except KeyError:
-        raise Warning(
+        eval_logger.warning(
             "{} not a registered aggregation metric!".format(name),
         )

@@ -162,7 +163,9 @@ def get_default_aggregation(metric_name):
     try:
         return DEFAULT_AGGREGATION_REGISTRY[metric_name]
     except KeyError:
-        raise Warning(f"No default aggregation metric for metric '{metric_name}'!")
+        eval_logger.warning(
+            f"No default aggregation metric for metric '{metric_name}'!"
+        )


 def is_higher_better(metric_name):

@@ -170,3 +173,6 @@ def is_higher_better(metric_name):
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:
-        raise Warning(f"higher_is_better not specified for metric '{metric_name}'!")
+        eval_logger.warning(
+            f"higher_is_better not specified for metric '{metric_name}'!"
+        )
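
Worth noting about this change: Warning is a built-in Exception subclass, so the old raise Warning(...) aborted the lookup entirely, whereas logging through eval_logger lets execution continue and the function simply returns None. A minimal standalone sketch of the difference, with Python's stdlib logging standing in for eval_logger and a toy registry standing in for AGGREGATION_REGISTRY:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("registry-demo")

# Toy stand-in for AGGREGATION_REGISTRY.
AGGREGATIONS = {"mean": sum}


def get_aggregation_old_style(name):
    try:
        return AGGREGATIONS[name]
    except KeyError:
        # Warning is an Exception subclass, so this propagates to the caller.
        raise Warning("{} not a registered aggregation metric!".format(name))


def get_aggregation_new_style(name):
    try:
        return AGGREGATIONS[name]
    except KeyError:
        # Logged instead of raised; the function falls through and returns None.
        logger.warning("%s not a registered aggregation metric!", name)


try:
    get_aggregation_old_style("median")
except Warning as w:
    print("old style raised:", w)

print("new style returned:", get_aggregation_new_style("median"))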
lm_eval/api/task.py

@@ -70,6 +70,7 @@ class TaskConfig(dict):
    doc_to_target: Union[Callable, str] = None
    doc_to_choice: Union[Callable, str, dict, list] = None
    gold_alias: Union[Callable, str] = None
    process_results: Union[Callable, str] = None
    use_prompt: str = None
    description: str = ""
    target_delimiter: str = " "

@@ -545,8 +546,18 @@ class ConfigurableTask(Task):
                     for key in metric_config
                     if key not in ["metric", "aggregation", "higher_is_better"]
                 }
-                self._metric_fn_list[metric_name] = get_metric(metric_name)
-                self._metric_fn_kwargs[metric_name] = kwargs
+                if self._config.process_results is not None:
+                    self._metric_fn_list[metric_name] = None
+                    self._metric_fn_kwargs[metric_name] = {}
+                elif callable(metric_name):
+                    metric_fn = metric_name.__call__
+                    metric_name = metric_name.__name__
+                    self._metric_fn_list[metric_name] = metric_fn
+                    self._metric_fn_kwargs[metric_name] = kwargs
+                else:
+                    self._metric_fn_list[metric_name] = get_metric(metric_name)
+                    self._metric_fn_kwargs[metric_name] = kwargs

             if "aggregation" in metric_config:
                 agg_name = metric_config["aggregation"]

@@ -885,8 +896,8 @@ class ConfigurableTask(Task):
     def process_results(self, doc, results):
-        # if callable(self._config.process_results):
-        #     return self._config.process_results(doc, results)
+        if callable(self._config.process_results):
+            return self._config.process_results(doc, results)

         result_dict = {}
         use_metric = list(self._metric_fn_list.keys())

@@ -975,7 +986,11 @@
         elif self.OUTPUT_TYPE == "greedy_until":
             gold = self.doc_to_target(doc)
+            if type(gold) == int:
+                choices = self.doc_to_choice(doc)
+                gold = choices[gold]
+            print(self._metric_fn_list)

             for key, result in zip(self._metric_fn_list.keys(), results):
                 if self.multiple_target:
                     # in the case where we have multiple targets,
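
The new branching in ConfigurableTask's metric setup decides how each metric_list entry is resolved: if the config supplies process_results, no metric function is attached; a callable is registered under its own __name__; anything else goes through get_metric. A simplified, standalone mirror of that dispatch, for illustration only (the function name and toy registry below are not the harness's API):

def resolve_metric(metric_name, kwargs, process_results=None, metric_registry=None):
    """Toy version of the dispatch added above."""
    metric_registry = metric_registry or {}
    metric_fn_list, metric_fn_kwargs = {}, {}

    if process_results is not None:
        # The task's own process_results computes metrics, so nothing is attached.
        metric_fn_list[metric_name] = None
        metric_fn_kwargs[metric_name] = {}
    elif callable(metric_name):
        # A user-supplied function is stored under its own name.
        metric_fn = metric_name.__call__
        metric_name = metric_name.__name__
        metric_fn_list[metric_name] = metric_fn
        metric_fn_kwargs[metric_name] = kwargs
    else:
        # A plain string is looked up in the metric registry.
        metric_fn_list[metric_name] = metric_registry[metric_name]
        metric_fn_kwargs[metric_name] = kwargs

    return metric_fn_list, metric_fn_kwargs


def my_accuracy(gold, pred):
    return float(gold == pred)


print(resolve_metric(my_accuracy, {}))                                    # keyed by "my_accuracy"
print(resolve_metric("acc", {}, metric_registry={"acc": my_accuracy}))    # resolved via registry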
lm_eval/evaluator.py

@@ -192,14 +192,35 @@ def evaluate(
    # decontaminate = decontamination_ngrams_path is not None

    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    # Tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    # Stores group related keys and values for group-aggregation
    # (task scores based on task grouping)
    aggregate = collections.defaultdict(dict)
    # tracks if a task was chosen via user selecting a group containing it
    task_groups = collections.defaultdict(dict)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group, task = task
            task_groups[task_name] = group

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

@@ -243,6 +264,7 @@ def evaluate(
             # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
             numpad = max(gathered_item) - gathered_item[lm.rank]
+            padding_requests[task.OUTPUT_TYPE] += numpad

     ### Run LM on inputs, get all outputs ###
     # execute each type of request

@@ -253,8 +275,8 @@
         for req in reqs:
             cloned_reqs.extend([req] * req.repeats)

-        if (lm.world_size > 1) and (numpad > 0):
-            for _ in range(numpad):
+        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
+            for _ in range(padding_requests[reqtype]):
                 cloned_reqs.extend([req] * req.repeats)

         # run requests through model

@@ -264,12 +286,14 @@
         for x, req in zip(resps, cloned_reqs):
             req.resps.append(x)

         if lm.world_size > 1:
             lm.accelerator.wait_for_everyone()

     ### Postprocess outputs ###
     # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
     for task_name, task in task_dict.items():
+        if type(task) == tuple:
+            group, task = task
+
         task.apply_filters()

     ### Collect values of metrics on all datapoints ###

@@ -277,6 +301,8 @@
     # unpack results and sort back in order and return control to Task
     for task_name, task in task_dict.items():
+        if type(task) == tuple:
+            group, task = task
+
         # TODO: make it possible to use a different metric per filter
         # iterate over different filters used
         for key in task.instances[0].filtered_resps.keys():

@@ -362,7 +388,23 @@
     # aggregate results ; run bootstrap CIs
     for (task_name, key, metric), items in vals.items():
         task = task_dict[task_name]
-        results[task_name][metric + "," + key] = task.aggregation()[metric](items)
+        if type(task) == tuple:
+            group, task = task
+
+        task_score = task.aggregation()[metric](items)
+        results[task_name][metric + "," + key] = task_score
+
+        # Need to put back in results
+        # pythia | acc
+        #        | perplexity
+        #        | word_perplexity
+        #        | byte_perplexity
+        #        | bits_per_byte
+        if bool(task_groups):
+            group_name = task_groups[task_name]
+            if metric not in aggregate[group_name]:
+                aggregate[group_name][metric] = [task_score]
+            else:
+                aggregate[group_name][metric].append(task_score)

         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this

@@ -377,10 +419,21 @@
         if stderr is not None:
             results[task_name][metric + "_stderr" + "," + key] = stderr(items)

+    if bool(aggregate):
+        for group in aggregate.keys():
+            for metric in aggregate[group].keys():
+                aggregate[group][metric] = np.average(aggregate[group][metric])
+            versions[group] = "N/A"
+
     results_dict = {
-        "results": dict(results),
-        "configs": dict(configs),
-        "versions": dict(versions),
+        "results": dict(sorted(results.items())),
+        **(
+            {"aggregate": dict(sorted(aggregate.items()))}
+            if bool(aggregate)
+            else {}
+        ),
+        "configs": dict(sorted(configs.items())),
+        "versions": dict(sorted(versions.items())),
     }

     if log_samples:
         results_dict["samples"] = dict(samples)
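
The group bookkeeping added here amounts to: record each task's group when the task dict carries (group, task) tuples, append every per-task score to aggregate[group][metric], then reduce each list with np.average. A toy, self-contained illustration of that flow (the task names and scores below are invented, not real results):

import collections

import numpy as np

# Invented per-task scores; in evaluate() these come from task.aggregation()[metric](items).
task_scores = {"lambada_openai": 0.70, "piqa": 0.78, "sciq": 0.93}
task_groups = {name: "pythia" for name in task_scores}  # each task mapped to its group

aggregate = collections.defaultdict(dict)
for task_name, task_score in task_scores.items():
    group_name = task_groups[task_name]
    if "acc" not in aggregate[group_name]:
        aggregate[group_name]["acc"] = [task_score]
    else:
        aggregate[group_name]["acc"].append(task_score)

# Reduce each metric's list of task scores to a single group-level average.
for group in aggregate.keys():
    for metric in aggregate[group].keys():
        aggregate[group][metric] = np.average(aggregate[group][metric])

print(dict(aggregate))  # {'pythia': {'acc': 0.80...}}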
lm_eval/prompts/__init__.py

+from lm_eval import utils
 from lm_eval.logger import eval_logger

 # Prompt library.

@@ -51,3 +52,17 @@ def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
             f"expected only a single `:` as separator between \
                 prompt category and name, but got `{prompt_id}` instead"
         )
+
+
+def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwargs):
+
+    from promptsource.templates import DatasetTemplates
+
+    if subset_name is None:
+        prompts = DatasetTemplates(dataset_name=dataset_name)
+    else:
+        prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)
+
+    category_name, prompt_name = use_prompt.split(":")
+
+    prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
+
+    return [":".join([category_name, prompt]) for prompt in prompt_list]
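
load_prompt_list leans on utils.pattern_match to expand a wildcard such as promptsource:* into one fully qualified prompt id per matching promptsource template. Conceptually the matching is glob-style; a rough standalone sketch with fnmatch standing in for utils.pattern_match (the template names below are assumed, not read from promptsource):

import fnmatch


def expand_prompts(use_prompt, all_template_names):
    # "promptsource:*" -> ["promptsource:<template>", ...] for every matching template.
    category_name, prompt_name = use_prompt.split(":")
    matched = [t for t in all_template_names if fnmatch.fnmatch(t, prompt_name)]
    return [":".join([category_name, t]) for t in matched]


templates = ["GPT-3 style", "MNLI crowdsource", "can we infer"]  # assumed template names
print(expand_prompts("promptsource:*", templates))
print(expand_prompts("promptsource:GPT-3 style", templates))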
lm_eval/tasks/__init__.py

 import os
+import yaml

 from typing import List, Union

 from lm_eval import utils
+from lm_eval import prompts
 from lm_eval.logger import eval_logger

 from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
 from lm_eval.api.registry import (

@@ -13,6 +15,58 @@ from lm_eval.api.registry import (
 )


+def register_configurable_task(config):
+    SubClass = type(
+        config["task"] + "ConfigurableTask",
+        (ConfigurableTask,),
+        {"CONFIG": TaskConfig(**config)},
+    )
+
+    if "task" in config:
+        task_name = "{}".format(config["task"])
+        register_task(task_name)(SubClass)
+
+    if "group" in config:
+        if type(config["group"]) == str:
+            group_name = [config["group"]]
+        else:
+            group_name = config["group"]
+
+        for group in group_name:
+            register_group(group)(SubClass)
+
+    return 0
+
+
+def check_prompt_config(config):
+    all_configs = []
+    if "use_prompt" in config:
+        prompt_list = prompts.load_prompt_list(
+            use_prompt=config["use_prompt"],
+            dataset_name=config["dataset_path"],
+            subset_name=config["dataset_name"],
+        )
+        for idx, prompt_variation in enumerate(prompt_list):
+            all_configs.append(
+                {
+                    **config,
+                    **{"use_prompt": prompt_variation},
+                    **{
+                        "task": "_".join(
+                            [
+                                get_task_name_from_config(config),
+                                prompt_variation,
+                            ]
+                        )
+                    },
+                    **{"output_type": "greedy_until"},
+                }
+            )
+    else:
+        all_configs.append(config)
+    return all_configs
+
+
 def get_task_name_from_config(task_config):
     if "dataset_name" in task_config:
         return "{dataset_path}_{dataset_name}".format(**task_config)

@@ -31,23 +85,10 @@ def include_task_folder(task_dir):
                 yaml_path = os.path.join(root, f)
                 try:
                     config = utils.load_yaml_config(yaml_path)
-                    SubClass = type(
-                        config["task"] + "ConfigurableTask",
-                        (ConfigurableTask,),
-                        {"CONFIG": TaskConfig(**config)},
-                    )
-
-                    if "task" in config:
-                        # task_name = "{}:{}".format(
-                        #     get_task_name_from_config(config), config["task"]
-                        # )
-                        task_name = "{}".format(config["task"])
-                        register_task(task_name)(SubClass)
-
-                    if "group" in config:
-                        for group in config["group"]:
-                            register_group(group)(SubClass)
+                    all_configs = check_prompt_config(config)
+                    for config in all_configs:
+                        register_configurable_task(config)
                 except Exception as error:
                     eval_logger.warning(
                         "Failed to load config in \n"

@@ -57,8 +98,58 @@ def include_task_folder(task_dir):
                     )


+def include_benchmarks(task_dir, benchmark_dir="benchmarks"):
+    for root, subdirs, file_list in os.walk(os.path.join(task_dir, benchmark_dir)):
+        if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
+            for f in file_list:
+                if f.endswith(".yaml"):
+                    try:
+                        benchmark_path = os.path.join(root, f)
+                        with open(benchmark_path, "rb") as file:
+                            yaml_config = yaml.full_load(file)
+
+                        assert "group" in yaml_config
+                        group = yaml_config["group"]
+                        all_task_list = yaml_config["task"]
+                        config_list = [task for task in all_task_list if type(task) != str]
+                        task_list = [task for task in all_task_list if type(task) == str]
+
+                        for task_config in config_list:
+                            var_configs = check_prompt_config(
+                                {
+                                    **task_config,
+                                    **{"group": group},
+                                }
+                            )
+                            for config in var_configs:
+                                register_configurable_task(config)
+
+                        task_names = utils.pattern_match(task_list, ALL_TASKS)
+                        for task in task_names:
+                            if task in TASK_REGISTRY:
+                                if group in GROUP_REGISTRY:
+                                    GROUP_REGISTRY[group].append(task)
+                                else:
+                                    GROUP_REGISTRY[group] = [task]
+                                    ALL_TASKS.add(group)
+                    except Exception as error:
+                        eval_logger.warning(
+                            "Failed to load benchmark in\n"
+                            f" {benchmark_path}\n"
+                            " Benchmark will not be added to registry\n"
+                            f" Error: {error}"
+                        )
+
+
 task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
 include_task_folder(task_dir)
+include_benchmarks(task_dir)


 def get_task(task_name, config):

@@ -97,11 +188,15 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
         if isinstance(task_element, str):
             if task_element in GROUP_REGISTRY:
+                group_name = task_element
                 for task_name in GROUP_REGISTRY[task_element]:
                     if task_name not in task_name_from_registry_dict:
                         task_name_from_registry_dict = {
                             **task_name_from_registry_dict,
-                            task_name: get_task(task_name=task_name, config=config),
+                            task_name: (
+                                group_name,
+                                get_task(task_name=task_name, config=config),
+                            ),
                         }
             else:
                 task_name = task_element
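
register_configurable_task builds one ConfigurableTask subclass per YAML config using the three-argument form of type(), then registers it under the config's task and group names. A minimal standalone illustration of that dynamic-subclass idiom (the base class and config below are toys, not the harness's classes):

class ToyConfigurableTask:
    CONFIG = None

    def describe(self):
        return f"{type(self).__name__}: {self.CONFIG}"


config = {"task": "boolq", "dataset_path": "super_glue", "dataset_name": "boolq"}

# Same idiom as register_configurable_task: type(name, bases, namespace).
SubClass = type(
    config["task"] + "ConfigurableTask",
    (ToyConfigurableTask,),
    {"CONFIG": config},
)

print(SubClass.__name__)       # boolqConfigurableTask
print(SubClass().describe())   # boolqConfigurableTask: {'task': 'boolq', ...}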
lm_eval/tasks/benchmarks/pythia.yaml (new file)

group: pythia
task:
  - lambada_openai
  - wikitext
  - piqa
  - sciq
  - wsc
  - winogrande
  - arc_*
  # - logiqa
  # - blimp_*
  # - hendrycksTest*
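
Per include_benchmarks above, entries under task: may be plain task names, glob-style patterns such as arc_* (expanded against the registry via utils.pattern_match), or inline task configs. A rough sketch of how such a file could be split apart and expanded, assuming PyYAML and an invented registry:

import fnmatch

import yaml

yaml_text = """
group: pythia
task:
  - lambada_openai
  - piqa
  - arc_*
"""

registry = ["arc_easy", "arc_challenge", "lambada_openai", "piqa", "sciq"]  # invented subset

config = yaml.full_load(yaml_text)
group = config["group"]
task_list = [t for t in config["task"] if isinstance(t, str)]  # plain names and patterns

expanded = []
for pattern in task_list:
    matches = fnmatch.filter(registry, pattern)
    expanded.extend(matches if matches else [pattern])

print(group, expanded)  # pythia ['lambada_openai', 'piqa', 'arc_easy', 'arc_challenge']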
lm_eval/tasks/benchmarks/t0_eval.yaml (new file)

group: t0_eval
task:
  # # Coreference Resolution
  # - dataset_path: super_glue
  #   dataset_name: wsc.fixed
  #   use_prompt: promptsource:*
  #   training_split: train
  #   validation_split: validation
  #   metric_list:
  #     - metric: exact_match
  #       aggregation: mean
  #       higher_is_better: true
  #       ignore_case: true
  #       ignore_punctuation: true
  # # Coreference Resolution
  # - dataset_path: winogrande
  #   dataset_name: winogrande_xl
  #   use_prompt: promptsource:*
  #   training_split: train
  #   validation_split: validation
  #   metric_list:
  #     - metric: exact_match
  #       aggregation: mean
  #       higher_is_better: true
  #       ignore_case: true
  #       ignore_punctuation: true
  # Natural Language Inference
  - dataset_path: super_glue
    dataset_name: cb
    use_prompt: promptsource:*
    training_split: train
    validation_split: validation
    output_type: greedy_until
    metric_list:
      - metric: exact_match
        aggregation: mean
        higher_is_better: true
        ignore_case: true
        ignore_punctuation: true
  # Natural Language Inference
  # - dataset_path: super_glue
  #   dataset_name: rte
  #   use_prompt: promptsource:*
  #   training_split: train
  #   validation_split: validation
  #   metric_list:
  #     - metric: exact_match
  #       aggregation: mean
  #       higher_is_better: true
  #       ignore_case: true
  #       ignore_punctuation: true
  # # Natural Language Inference
  # # - dataset_path: anli
  # #   use_prompt: promptsource:*
  # #   training_split: train_r1
  # #   validation_split: dev_r1
  # # Sentence Completion
  # - dataset_path: super_glue
  #   dataset_name: copa
  #   use_prompt: promptsource:*
  #   training_split: train
  #   validation_split: validation
  #   metric_list:
  #     - metric: exact_match
  #       aggregation: mean
  #       higher_is_better: true
  #       ignore_case: true
  #       ignore_punctuation: true
  # # Natural Language Inference
  # - dataset_path: hellaswag
  #   use_prompt: promptsource:*
  #   training_split: train
  #   validation_split: validation
  #   metric_list:
  #     - metric: exact_match
  #       aggregation: mean
  #       higher_is_better: true
  #       ignore_case: true
  #       ignore_punctuation: true
  # # Word Sense Disambiguation
  # - dataset_path: super_glue
  #   dataset_name: wic
  #   use_prompt: promptsource:*
  #   training_split: train
  #   validation_split: validation
  #   metric_list:
  #     - metric: exact_match
  #       aggregation: mean
  #       higher_is_better: true
  #       ignore_case: true
  #       ignore_punctuation: true
lm_eval/tasks/super_glue/boolq/default.yaml

 group:
   - super-glue-lm-eval-v1
-task: "boolq"
+task: boolq
 dataset_path: super_glue
 dataset_name: boolq
 output_type: multiple_choice
lm_eval/tasks/super_glue/cb/t5-prompt.yaml

@@ -5,11 +5,15 @@ dataset_path: super_glue
 dataset_name: cb
 training_split: train
 validation_split: validation
 output_type: greedy_until
 doc_to_text: "cb hypothesis: {{hypothesis}} premise {{premise}}"
-doc_to_target: "{% set answer_choices = ['entailment', 'contradiction', 'neutral'] %}{{answer_choices[label]}}"
+doc_to_target: label
+doc_to_choice: ['entailment', 'contradiction', 'neutral']
 metric_list:
   - metric: exact_match
     aggregation: mean
     higher_is_better: true
     ignore_case: true
     ignore_punctuation: true
+  - metric: f1
+    aggregation: !function "aggregate.cb_multi_fi"
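
Switching from the inline Jinja template to doc_to_target: label plus doc_to_choice pairs with the task.py change above: an integer gold label is mapped through the choice list before string metrics such as exact_match are computed. A toy illustration of that mapping (the doc dict is made up):

choices = ["entailment", "contradiction", "neutral"]  # doc_to_choice


def gold_string(doc):
    gold = doc["label"]          # doc_to_target: label yields an integer index
    if type(gold) == int:        # same check added in ConfigurableTask.process_results
        gold = choices[gold]
    return gold


print(gold_string({"label": 2}))  # neutral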
lm_eval/tasks/super_glue/copa/t5-prompt.yaml

@@ -5,8 +5,10 @@ dataset_path: super_glue
 dataset_name: copa
 training_split: train
 validation_split: validation
 output_type: greedy_until
 doc_to_text: "copa choice1: {{choice1}} choice2: {{choice2}} question: {{question}}"
-doc_to_target: "{% set answer_choices = ['False', 'True'] %}{{answer_choices[label]}}"
+doc_to_target: label
+doc_to_choice: ['False', 'True']
 metric_list:
   - metric: exact_match
     aggregation: mean
lm_eval/tasks/super_glue/multirc/promptsource-00.yaml (deleted)

group:
  - super-glue-promptsource
task: "I was going to say…"
dataset_path: super_glue
dataset_name: multirc
training_split: train
validation_split: validation
use_prompt: "promptsource:I was going to say…"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
lm_eval/tasks/super_glue/multirc/promptsource-01.yaml (deleted)

include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "Would it be good to answer…"
use_prompt: "promptsource:Would it be good to answer…"
lm_eval/tasks/super_glue/multirc/promptsource-02.yaml (deleted)

include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "confirm"
use_prompt: "promptsource:confirm"
lm_eval/tasks/super_glue/record/default.yaml

-# group:
-#  - super-glue-lm-eval-v1
+group:
+  - super-glue-lm-eval-v1
 task: record
 dataset_path: super_glue
 dataset_name: record

@@ -9,6 +9,10 @@ validation_split: validation
 doc_to_text: !function util.doc_to_text
 doc_to_target: "{{answers}}"
 doc_to_choice: "{{entities}}"
+process_results: !function util.process_results
 metric_list:
   - metric: f1
     aggregation: mean
+  - metric: em
+    higher_is_better: True
+    aggregation: mean
lm_eval/tasks/super_glue/record/promptsource-00.yaml (deleted)

group:
  - super-glue-promptsource
task: "Add sentence after (continuation choices)"
dataset_path: super_glue
dataset_name: record
training_split: train
validation_split: validation
use_prompt: "promptsource:Add sentence after (continuation choices)"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
lm_eval/tasks/super_glue/record/t5-prompt.yaml

@@ -5,6 +5,7 @@ dataset_path: super_glue
 dataset_name: record
 training_split: train
 validation_split: validation
+output_type: greedy_until
 doc_to_text: "record query: {{query}} entities: {{entities}} passage: {{passage}}"
 doc_to_target: "{{answers}}"
 metric_list:
lm_eval/tasks/super_glue/record/util.py

+import numpy as np
+import transformers.data.metrics.squad_metrics as squad_metrics
+
+from lm_eval.api.metrics import metric_max_over_ground_truths
+
+
 def doc_to_text(doc):
     initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
     text = initial_text + "\n\n"

@@ -13,3 +19,25 @@ def format_answer(query, entity):
 def doc_to_target(doc):
     # We only output the first correct entity in a doc
     return format_answer(query=doc["query"], entity=doc["answers"][0])
+
+
+def process_results(doc, results):
+    # ReCoRD's evaluation is actually deceptively simple:
+    # - Pick the maximum likelihood prediction entity
+    # - Evaluate the accuracy and token F1 PER EXAMPLE
+    # - Average over all examples
+    max_idx = np.argmax(np.array([result[0] for result in results]))
+
+    prediction = doc["entities"][max_idx]
+    gold_label_set = doc["answers"]
+    f1 = metric_max_over_ground_truths(
+        squad_metrics.compute_f1, prediction, gold_label_set
+    )
+    em = metric_max_over_ground_truths(
+        squad_metrics.compute_exact, prediction, gold_label_set
+    )
+
+    return {
+        "f1": f1,
+        "em": em,
+    }
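
metric_max_over_ground_truths follows the usual SQuAD-style convention of scoring the prediction against every gold answer and keeping the best score. The real helper lives in lm_eval.api.metrics, so the version below is only a behavioral sketch with a toy metric, not the harness's implementation:

def max_over_ground_truths(metric_fn, prediction, gold_answers):
    # Score the prediction against each gold answer and keep the best match.
    return max(metric_fn(prediction, gold) for gold in gold_answers)


def toy_exact(prediction, gold):
    return float(prediction.strip().lower() == gold.strip().lower())


print(max_over_ground_truths(toy_exact, "Tesla", ["Nikola Tesla", "Tesla"]))   # 1.0
print(max_over_ground_truths(toy_exact, "Edison", ["Nikola Tesla", "Tesla"]))  # 0.0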
lm_eval/tasks/super_glue/rte/promptsource-00.yaml → lm_eval/tasks/super_glue/rte/default.yaml (renamed)

 group:
-  - super-glue-promptsource
-task: "rte"
+  - super-glue-lm-eval-v1
+task: rte
 dataset_path: super_glue
 dataset_name: rte
+output_type: multiple_choice
 training_split: train
 validation_split: validation
 use_prompt: "promptsource:GPT-3 style"
+generation_kwargs:
+  until:
+    - "\n"
+    - "\n\n"
+doc_to_text: "{{premise}}\nQuestion: {{hypothesis}} True or False?\nAnswer:"
+doc_to_target: label
+doc_to_choice: ['True', 'False']
 metric_list:
   - metric: exact_match
     aggregation: mean
     higher_is_better: true
     ignore_case: true
     ignore_punctuation: true
+  - metric: acc
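
The generation_kwargs.until entries are stop strings: when the task is run in a generation setting, the model's continuation is cut at the first occurrence of any of them before it is compared against the target. A rough standalone sketch of that truncation (not the harness's exact post-processing):

def truncate_at_stops(generation, until=("\n", "\n\n")):
    # Cut the generated text at the earliest stop sequence, if any appears.
    cut = len(generation)
    for stop in until:
        idx = generation.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return generation[:cut]


print(repr(truncate_at_stops(" True\nQuestion: next premise ...")))  # ' True'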
lm_eval/tasks/super_glue/rte/promptsource-01.yaml (deleted)

include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "MNLI crowdsource"
use_prompt: "promptsource:MNLI crowdsource"