Commit 328f0e85 authored by haileyschoelkopf

cleanup metric loading code

parent 6a2620ad
@@ -156,3 +156,17 @@ def get_aggregation(name):
         raise Warning(
             "{} not a registered aggregation metric!".format(name),
         )
+
+
+def get_default_aggregation(metric_name):
+    try:
+        return DEFAULT_AGGREGATION_REGISTRY[metric_name]
+    except KeyError:
+        raise Warning(f"No default aggregation metric for metric '{metric_name}'!")
+
+
+def is_higher_better(metric_name):
+    try:
+        return HIGHER_IS_BETTER_REGISTRY[metric_name]
+    except KeyError:
+        raise Warning(f"higher_is_better not specified for metric '{metric_name}'!")
@@ -24,19 +24,18 @@ from lm_eval.logger import eval_logger
 from lm_eval.prompts import get_prompt
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.api.metrics import (
-    # get_metric,
-    # get_aggregation,
     mean,
     weighted_perplexity,
     bits_per_byte,
 )
 from lm_eval.api.registry import (
-    METRIC_REGISTRY,
+    get_metric,
+    get_aggregation,
+    get_default_aggregation,
+    is_higher_better,
     DEFAULT_METRIC_REGISTRY,
     OUTPUT_TYPE_REGISTRY,
     AGGREGATION_REGISTRY,
-    HIGHER_IS_BETTER_REGISTRY,
-    DEFAULT_AGGREGATION_REGISTRY,
 )
 
 ALL_OUTPUT_TYPES = [
@@ -517,13 +516,11 @@ class ConfigurableTask(Task):
         if self._config.metric_list is None:
             # TODO: handle this in TaskConfig.__post_init__ ?
             for metric_name in _metric_list:
-                self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
-                self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
+                self._metric_fn_list[metric_name] = get_metric(metric_name)
+                self._aggregation_list[metric_name] = get_default_aggregation(
                     metric_name
-                ]
-                self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
-                    metric_name
-                ]
+                )
+                self._higher_is_better[metric_name] = is_higher_better(metric_name)
         else:
             for metric_config in self._config.metric_list:
                 assert "metric" in metric_config
@@ -533,30 +530,13 @@
                     for key in metric_config
                     if key not in ["metric", "aggregation", "higher_is_better"]
                 }
-                try:
-                    self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
-                except Exception:
-                    eval_logger.warning(
-                        f"Metric {metric_name} not found, "
-                        "Searching from https://huggingface.co/evaluate-metric"
-                    )
-                    try:
-                        metric_object = evaluate.load(metric_name)
-                        self._metric_fn_list[metric_name] = metric_object
-                        self._metric_fn_kwargs[metric_name] = kwargs
-                    except Exception:
-                        raise Warning(
-                            "{} not found in the evaluate library!".format(metric_name),
-                            "Please check https://huggingface.co/evaluate-metric",
-                        )
+                self._metric_fn_list[metric_name] = get_metric(metric_name)
+                self._metric_fn_kwargs[metric_name] = kwargs
 
                 if "aggregation" in metric_config:
                     agg_name = metric_config["aggregation"]
                     if type(agg_name) == str:
-                        self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[
-                            agg_name
-                        ]
+                        self._aggregation_list[metric_name] = get_aggregation(agg_name)
                     elif callable(agg_name):
                         self._aggregation_list[metric_name] = metric_config[
                             "aggregation"
@@ -564,7 +544,7 @@
                 else:
                     INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
-                    metric_agg = DEFAULT_AGGREGATION_REGISTRY[metric_name]
+                    metric_agg = get_default_aggregation(metric_name)
                     eval_logger.warning(
                         f"metric {metric_name} is defined, but aggregation is not. "
                         f"using default "
@@ -580,11 +560,9 @@
                     eval_logger.warning(
                         f"metric {metric_name} is defined, but higher_is_better is not. "
                         f"using default "
-                        f"higher_is_better={HIGHER_IS_BETTER_REGISTRY[metric_name]}"
+                        f"higher_is_better={is_higher_better(metric_name)}"
                     )
-                    self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
-                        metric_name
-                    ]
+                    self._higher_is_better[metric_name] = is_higher_better(metric_name)
 
         self.download(self._config.dataset_kwargs)
         self._training_docs = None
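For the explicit `metric_list` path, each entry is a dict parsed from the task YAML: keys other than `metric`, `aggregation`, and `higher_is_better` are forwarded to the metric function as kwargs, and missing `aggregation`/`higher_is_better` keys fall back to the registry defaults with a logged warning. A sketch of the kwargs split on an assumed entry (the `ignore_case` key is illustrative):

```python
# Assumed shape of one metric_list entry after YAML parsing.
metric_config = {"metric": "exact_match", "aggregation": "mean", "ignore_case": True}

kwargs = {
    key: metric_config[key]
    for key in metric_config
    if key not in ["metric", "aggregation", "higher_is_better"]
}
print(kwargs)  # -> {'ignore_case': True}
# "higher_is_better" is absent, so the code above would log a warning and
# fall back to is_higher_better("exact_match").
```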
@@ -887,7 +865,7 @@ class ConfigurableTask(Task):
             gold = self.doc_to_target(doc)
 
             for key, result in zip(self._metric_fn_list.keys(), results):
-                _dict = self._metric_fn_list[key].compute(
+                _dict = self._metric_fn_list[key](
                     references=[gold],
                     predictions=[result],
                     **self._metric_fn_kwargs[key],
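Dropping `.compute(` means every entry in `_metric_fn_list` is now expected to be a plain callable taking `references=` and `predictions=` keyword arguments. For that to hold, `get_metric` presumably adapts Hugging Face `evaluate` modules (which expose `.compute(...)`) to the same calling convention; one way that could look, assumed rather than taken from this diff:

```python
import evaluate


def get_metric_sketch(name):
    """Load a hub metric and adapt it to the plain-callable convention."""
    metric_object = evaluate.load(name)

    def fn(references, predictions, **kwargs):
        return metric_object.compute(
            references=references, predictions=predictions, **kwargs
        )

    return fn
```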
@@ -19,6 +19,6 @@ metric_list:
   - metric: acc_norm
     aggregation: mean
     higher_is_better: true
-  - metric: acc_mutual_info
-    aggregation: mean
-    higher_is_better: true
+  # - metric: acc_mutual_info
+  #   aggregation: mean
+  #   higher_is_better: true
@@ -8,7 +8,13 @@ training_split: train
 validation_split: validation
 doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
 doc_to_target: "{{answer_choices[label]}}"
-gold_alias: "{{label}}" # this will be cast to an int.
+gold_alias: " {{answer_choices[label]}}" # this will be cast to an int.
+generation_kwargs:
+  until:
+    - "\n\n"
+    - "\n"
+  do_sample: false
+  temperature: 0.0
 template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
 metric_list:
   - metric: exact_match
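The new `generation_kwargs` block pins greedy decoding (`do_sample: false`, `temperature: 0.0`) and declares stop sequences. `until` conventionally means "cut the continuation at the first stop sequence"; an illustrative helper with those semantics (not the harness's actual implementation):

```python
def truncate_at_stop(continuation: str, until: list[str]) -> str:
    """Cut a generated continuation at the earliest stop sequence."""
    for stop in until:
        continuation = continuation.split(stop)[0]
    return continuation


print(truncate_at_stop("yes\n\nQuestion: ...", ["\n\n", "\n"]))  # -> "yes"
```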
 group:
   - super-glue-promptsource
-task: "GPT-3 style"
+task: "rte"
 dataset_path: super_glue
 dataset_name: rte
 training_split: train
 validation_split: validation
+use_prompt: "promptsource:GPT-3 style"
 generation_kwargs:
   until:
     - "\n"
     - "\n\n"
 metric_list:
   - metric: exact_match
     aggregation: mean
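Here the task key becomes a plain identifier (`rte`) and the promptsource template is selected explicitly through `use_prompt` instead of doubling as the task name. A minimal sketch of how a `"promptsource:<template name>"` value could be resolved with the promptsource library; `get_prompt`'s real implementation is not part of this diff, so treat the parsing scheme as an assumption:

```python
from promptsource.templates import DatasetTemplates


def get_prompt_sketch(use_prompt: str, dataset_path: str, dataset_name: str):
    # Split "promptsource:GPT-3 style" into a category and a template name.
    category, prompt_name = use_prompt.split(":", maxsplit=1)
    if category == "promptsource":
        # DatasetTemplates indexes templates by their human-readable name.
        return DatasetTemplates(dataset_path, dataset_name)[prompt_name]
    raise ValueError(f"unknown prompt category: {category}")


# e.g. get_prompt_sketch("promptsource:GPT-3 style", "super_glue", "rte")
```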