gaoqiong / lm-evaluation-harness · Commits

Commit 328f0e85
authored Jul 03, 2023 by haileyschoelkopf
parent 6a2620ad

cleanup metric loading code

Showing 5 changed files with 44 additions and 42 deletions (+44 −42)
lm_eval/api/registry.py                               +14 −0
lm_eval/api/task.py                                   +15 −37
lm_eval/tasks/arc/arc_challenge.yaml                   +3 −3
lm_eval/tasks/super_glue/boolq/seq2seq.yaml            +7 −1
lm_eval/tasks/super_glue/rte/promptsource-00.yaml      +5 −1
lm_eval/api/registry.py

```diff
@@ -156,3 +156,17 @@ def get_aggregation(name):
         raise Warning(
             "{} not a registered aggregation metric!".format(name),
         )
+
+
+def get_default_aggregation(metric_name):
+    try:
+        return DEFAULT_AGGREGATION_REGISTRY[metric_name]
+    except KeyError:
+        raise Warning(f"No default aggregation metric for metric '{metric_name}'!")
+
+
+def is_higher_better(metric_name):
+    try:
+        return HIGHER_IS_BETTER_REGISTRY[metric_name]
+    except KeyError:
+        raise Warning(f"higher_is_better not specified for metric '{metric_name}'!")
```
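The new helpers follow the same pattern as the existing `get_aggregation`: a plain dict lookup behind a function, so callers get a descriptive failure instead of a bare `KeyError`. Note that `Warning` subclasses `Exception` in Python, so `raise Warning(...)` aborts the caller rather than merely emitting a warning. A minimal self-contained sketch of the pattern; the registry entries below are assumed for illustration, not the harness's actual contents:

```python
# Sketch of the registry-plus-getter pattern. Entries are assumed
# examples; the real registries are populated in lm_eval/api/registry.py.
DEFAULT_AGGREGATION_REGISTRY = {"acc": "mean"}  # assumed entry
HIGHER_IS_BETTER_REGISTRY = {"acc": True}       # assumed entry


def get_default_aggregation(metric_name):
    try:
        return DEFAULT_AGGREGATION_REGISTRY[metric_name]
    except KeyError:
        # Warning subclasses Exception, so this raises rather than warns.
        raise Warning(f"No default aggregation metric for metric '{metric_name}'!")


def is_higher_better(metric_name):
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
        raise Warning(f"higher_is_better not specified for metric '{metric_name}'!")


print(get_default_aggregation("acc"))  # mean
print(is_higher_better("acc"))         # True
```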
lm_eval/api/task.py

```diff
@@ -24,19 +24,18 @@ from lm_eval.logger import eval_logger
 from lm_eval.prompts import get_prompt
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.api.metrics import (
-    # get_metric,
-    # get_aggregation,
     mean,
     weighted_perplexity,
     bits_per_byte,
 )
 from lm_eval.api.registry import (
-    METRIC_REGISTRY,
+    get_metric,
+    get_aggregation,
+    get_default_aggregation,
+    is_higher_better,
     DEFAULT_METRIC_REGISTRY,
     OUTPUT_TYPE_REGISTRY,
     AGGREGATION_REGISTRY,
-    HIGHER_IS_BETTER_REGISTRY,
-    DEFAULT_AGGREGATION_REGISTRY,
 )

 ALL_OUTPUT_TYPES = [
```
```diff
@@ -517,13 +516,11 @@ class ConfigurableTask(Task):
         if self._config.metric_list is None:
             # TODO: handle this in TaskConfig.__post_init__ ?
             for metric_name in _metric_list:
-                self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
-                self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
-                    metric_name
-                ]
-                self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
-                    metric_name
-                ]
+                self._metric_fn_list[metric_name] = get_metric(metric_name)
+                self._aggregation_list[metric_name] = get_default_aggregation(
+                    metric_name
+                )
+                self._higher_is_better[metric_name] = is_higher_better(metric_name)
         else:
             for metric_config in self._config.metric_list:
                 assert "metric" in metric_config
```
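This branch covers tasks whose YAML omits `metric_list`: each default metric name is resolved through the new getters instead of by indexing the registries directly. A sketch of that flow, with `_metric_list` assumed to come from `DEFAULT_METRIC_REGISTRY` keyed by the task's output type (that registry is imported in the hunk above, but its exact shape and contents are assumptions here):

```python
# Sketch of the default-metric path for a task with no metric_list.
# All registry contents and shapes below are assumed for illustration.
METRIC_REGISTRY = {"acc": lambda golds, preds: sum(g == p for g, p in zip(golds, preds)) / len(golds)}
DEFAULT_AGGREGATION_REGISTRY = {"acc": "mean"}
HIGHER_IS_BETTER_REGISTRY = {"acc": True}
DEFAULT_METRIC_REGISTRY = {"multiple_choice": ["acc"]}  # assumed: keyed by output type


def get_metric(name):
    return METRIC_REGISTRY[name]


def get_default_aggregation(name):
    return DEFAULT_AGGREGATION_REGISTRY[name]


def is_higher_better(name):
    return HIGHER_IS_BETTER_REGISTRY[name]


_metric_fn_list, _aggregation_list, _higher_is_better = {}, {}, {}
for metric_name in DEFAULT_METRIC_REGISTRY["multiple_choice"]:
    # Every lookup now goes through a getter that can fail descriptively.
    _metric_fn_list[metric_name] = get_metric(metric_name)
    _aggregation_list[metric_name] = get_default_aggregation(metric_name)
    _higher_is_better[metric_name] = is_higher_better(metric_name)

print(_aggregation_list)  # {'acc': 'mean'}
```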
```diff
@@ -533,30 +530,13 @@ class ConfigurableTask(Task):
                     for key in metric_config
                     if key not in ["metric", "aggregation", "higher_is_better"]
                 }
-                try:
-                    self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
-                except Exception:
-                    eval_logger.warning(
-                        f"Metric {metric_name} not found, "
-                        "Searching from https://huggingface.co/evaluate-metric"
-                    )
-                    try:
-                        metric_object = evaluate.load(metric_name)
-                        self._metric_fn_list[metric_name] = metric_object
-                        self._metric_fn_kwargs[metric_name] = kwargs
-                    except Exception:
-                        raise Warning(
-                            "{} not found in the evaluate library!".format(metric_name),
-                            "Please check https://huggingface.co/evaluate-metric",
-                        )
+                self._metric_fn_list[metric_name] = get_metric(metric_name)
+                self._metric_fn_kwargs[metric_name] = kwargs

                 if "aggregation" in metric_config:
                     agg_name = metric_config["aggregation"]
                     if type(agg_name) == str:
-                        self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[agg_name]
+                        self._aggregation_list[metric_name] = get_aggregation(agg_name)
                     elif callable(agg_name):
                         self._aggregation_list[metric_name] = metric_config[
                             "aggregation"
```
```diff
@@ -564,7 +544,7 @@ class ConfigurableTask(Task):
                 else:
                     INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
-                    metric_agg = DEFAULT_AGGREGATION_REGISTRY[metric_name]
+                    metric_agg = get_default_aggregation(metric_name)
                     eval_logger.warning(
                         f"metric {metric_name} is defined, but aggregation is not. "
                         f"using default "
```
```diff
@@ -580,11 +560,9 @@ class ConfigurableTask(Task):
                     eval_logger.warning(
                         f"metric {metric_name} is defined, but higher_is_better is not. "
                         f"using default "
-                        f"higher_is_better={HIGHER_IS_BETTER_REGISTRY[metric_name]}"
+                        f"higher_is_better={is_higher_better(metric_name)}"
                     )
-                    self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
-                        metric_name
-                    ]
+                    self._higher_is_better[metric_name] = is_higher_better(metric_name)

         self.download(self._config.dataset_kwargs)
         self._training_docs = None
```
```diff
@@ -887,7 +865,7 @@ class ConfigurableTask(Task):
         gold = self.doc_to_target(doc)

         for key, result in zip(self._metric_fn_list.keys(), results):
-            _dict = self._metric_fn_list[key].compute(
+            _dict = self._metric_fn_list[key](
                 references=[gold],
                 predictions=[result],
                 **self._metric_fn_kwargs[key],
```
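The last hunk replaces the Hugging Face `evaluate`-style `metric.compute(...)` call with a direct call, which implies `get_metric` now hands back a plain callable regardless of where the metric came from. One way to get that uniform interface is to wrap loaded `evaluate` metrics in a closure; this is a hedged sketch of the idea, not necessarily how the harness implements `get_metric`:

```python
import evaluate  # Hugging Face evaluate library


def as_callable(metric_name):
    """Sketch, not the harness's actual get_metric: wrap evaluate.load()
    so callers can always invoke metric(references=..., predictions=...)
    without knowing whether the metric is native or from evaluate."""
    metric_object = evaluate.load(metric_name)

    def metric_fn(references, predictions, **kwargs):
        # evaluate metrics expose .compute(); hide that behind a plain call.
        return metric_object.compute(
            references=references, predictions=predictions, **kwargs
        )

    return metric_fn


exact_match = as_callable("exact_match")
print(exact_match(references=["yes"], predictions=["yes"]))  # {'exact_match': 1.0}
```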
lm_eval/tasks/arc/arc_challenge.yaml

```diff
@@ -19,6 +19,6 @@ metric_list:
   - metric: acc_norm
     aggregation: mean
     higher_is_better: true
-# - metric: acc_mutual_info
-#   aggregation: mean
-#   higher_is_better: true
+  - metric: acc_mutual_info
+    aggregation: mean
+    higher_is_better: true
```
lm_eval/tasks/super_glue/boolq/seq2seq.yaml

```diff
@@ -8,7 +8,13 @@ training_split: train
 validation_split: validation
 doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
 doc_to_target: "{{answer_choices[label]}}"
-gold_alias: "{{label}}" # this will be cast to an int.
+gold_alias: "{{answer_choices[label]}}" # this will be cast to an int.
+generation_kwargs:
+  until:
+    - "\n\n"
+    - "\n"
+  do_sample: false
+  temperature: 0.0
 template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
 metric_list:
   - metric: exact_match
```
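Both SuperGLUE YAMLs gain a `generation_kwargs.until` list of stop strings for the generation backend. How a backend applies them is implementation-specific; a minimal sketch of the usual semantics (truncate the continuation at the earliest occurrence of any stop string), stated here as an assumption:

```python
def truncate_at_stop(text: str, until: list[str]) -> str:
    # Assumed semantics: cut at the first occurrence of any stop string.
    cut = len(text)
    for stop in until:
        idx = text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


print(truncate_at_stop("yes\n\nQuestion: next doc...", ["\n\n", "\n"]))  # 'yes'
```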
lm_eval/tasks/super_glue/rte/promptsource-00.yaml

```diff
 group:
   - super-glue-promptsource
-task: "GPT-3 style"
+task: "rte"
 dataset_path: super_glue
 dataset_name: rte
 training_split: train
 validation_split: validation
 use_prompt: "promptsource:GPT-3 style"
+generation_kwargs:
+  until:
+    - "\n"
+    - "\n\n"
 metric_list:
   - metric: exact_match
     aggregation: mean
```
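`use_prompt: "promptsource:GPT-3 style"` names a prompt source and a template in one string, and `task.py` imports a `get_prompt` helper that presumably resolves it. The parsing below is a hypothetical sketch of that naming convention (split on the first colon), not the harness's actual implementation:

```python
def parse_prompt_spec(spec: str) -> tuple[str, str]:
    # Hypothetical helper: "promptsource:GPT-3 style"
    # -> ("promptsource", "GPT-3 style")
    source, _, name = spec.partition(":")
    return source, name


print(parse_prompt_spec("promptsource:GPT-3 style"))
```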