Commit 328f0e85 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

cleanup metric loading code

parent 6a2620ad
...@@ -156,3 +156,17 @@ def get_aggregation(name): ...@@ -156,3 +156,17 @@ def get_aggregation(name):
raise Warning( raise Warning(
"{} not a registered aggregation metric!".format(name), "{} not a registered aggregation metric!".format(name),
) )
def get_default_aggregation(metric_name):
    """Return the default aggregation function registered for ``metric_name``.

    Args:
        metric_name: Name under which the metric was registered in
            ``DEFAULT_AGGREGATION_REGISTRY``.

    Raises:
        Warning: If no default aggregation is registered for the metric.
            NOTE(review): raising a ``Warning`` subclass as an exception is
            unusual (callers might expect ``KeyError``), but it matches the
            existing convention of the sibling lookup helpers in this module.
    """
    try:
        return DEFAULT_AGGREGATION_REGISTRY[metric_name]
    except KeyError as err:
        # Chain the original KeyError so tracebacks show the failed lookup.
        raise Warning(
            f"No default aggregation metric for metric '{metric_name}'!"
        ) from err
def is_higher_better(metric_name):
    """Return whether larger values of ``metric_name`` indicate better results.

    Args:
        metric_name: Name under which the metric was registered in
            ``HIGHER_IS_BETTER_REGISTRY``.

    Raises:
        Warning: If ``higher_is_better`` was never registered for the metric.
            NOTE(review): raising a ``Warning`` subclass as an exception is
            unusual (callers might expect ``KeyError``), but it matches the
            existing convention of the sibling lookup helpers in this module.
    """
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError as err:
        # Chain the original KeyError so tracebacks show the failed lookup.
        raise Warning(
            f"higher_is_better not specified for metric '{metric_name}'!"
        ) from err
...@@ -24,19 +24,18 @@ from lm_eval.logger import eval_logger ...@@ -24,19 +24,18 @@ from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble from lm_eval.filters import build_filter_ensemble
from lm_eval.api.metrics import ( from lm_eval.api.metrics import (
# get_metric,
# get_aggregation,
mean, mean,
weighted_perplexity, weighted_perplexity,
bits_per_byte, bits_per_byte,
) )
from lm_eval.api.registry import ( from lm_eval.api.registry import (
METRIC_REGISTRY, get_metric,
get_aggregation,
get_default_aggregation,
is_higher_better,
DEFAULT_METRIC_REGISTRY, DEFAULT_METRIC_REGISTRY,
OUTPUT_TYPE_REGISTRY, OUTPUT_TYPE_REGISTRY,
AGGREGATION_REGISTRY, AGGREGATION_REGISTRY,
HIGHER_IS_BETTER_REGISTRY,
DEFAULT_AGGREGATION_REGISTRY,
) )
ALL_OUTPUT_TYPES = [ ALL_OUTPUT_TYPES = [
...@@ -517,13 +516,11 @@ class ConfigurableTask(Task): ...@@ -517,13 +516,11 @@ class ConfigurableTask(Task):
if self._config.metric_list is None: if self._config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ? # TODO: handle this in TaskConfig.__post_init__ ?
for metric_name in _metric_list: for metric_name in _metric_list:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name] self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[ self._aggregation_list[metric_name] = get_default_aggregation(
metric_name metric_name
] )
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[ self._higher_is_better[metric_name] = is_higher_better(metric_name)
metric_name
]
else: else:
for metric_config in self._config.metric_list: for metric_config in self._config.metric_list:
assert "metric" in metric_config assert "metric" in metric_config
...@@ -533,30 +530,13 @@ class ConfigurableTask(Task): ...@@ -533,30 +530,13 @@ class ConfigurableTask(Task):
for key in metric_config for key in metric_config
if key not in ["metric", "aggregation", "higher_is_better"] if key not in ["metric", "aggregation", "higher_is_better"]
} }
try: self._metric_fn_list[metric_name] = get_metric(metric_name)
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name] self._metric_fn_kwargs[metric_name] = kwargs
except Exception:
eval_logger.warning(
f"Metric {metric_name} not found, "
"Searching from https://huggingface.co/evaluate-metric"
)
try:
metric_object = evaluate.load(metric_name)
self._metric_fn_list[metric_name] = metric_object
self._metric_fn_kwargs[metric_name] = kwargs
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(metric_name),
"Please check https://huggingface.co/evaluate-metric",
)
if "aggregation" in metric_config: if "aggregation" in metric_config:
agg_name = metric_config["aggregation"] agg_name = metric_config["aggregation"]
if type(agg_name) == str: if type(agg_name) == str:
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[ self._aggregation_list[metric_name] = get_aggregation(agg_name)
agg_name
]
elif callable(agg_name): elif callable(agg_name):
self._aggregation_list[metric_name] = metric_config[ self._aggregation_list[metric_name] = metric_config[
"aggregation" "aggregation"
...@@ -564,7 +544,7 @@ class ConfigurableTask(Task): ...@@ -564,7 +544,7 @@ class ConfigurableTask(Task):
else: else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()} INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = DEFAULT_AGGREGATION_REGISTRY[metric_name] metric_agg = get_default_aggregation(metric_name)
eval_logger.warning( eval_logger.warning(
f"metric {metric_name} is defined, but aggregation is not. " f"metric {metric_name} is defined, but aggregation is not. "
f"using default " f"using default "
...@@ -580,11 +560,9 @@ class ConfigurableTask(Task): ...@@ -580,11 +560,9 @@ class ConfigurableTask(Task):
eval_logger.warning( eval_logger.warning(
f"metric {metric_name} is defined, but higher_is_better is not. " f"metric {metric_name} is defined, but higher_is_better is not. "
f"using default " f"using default "
f"higher_is_better={HIGHER_IS_BETTER_REGISTRY[metric_name]}" f"higher_is_better={is_higher_better(metric_name)}"
) )
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[ self._higher_is_better[metric_name] = is_higher_better(metric_name)
metric_name
]
self.download(self._config.dataset_kwargs) self.download(self._config.dataset_kwargs)
self._training_docs = None self._training_docs = None
...@@ -887,7 +865,7 @@ class ConfigurableTask(Task): ...@@ -887,7 +865,7 @@ class ConfigurableTask(Task):
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
for key, result in zip(self._metric_fn_list.keys(), results): for key, result in zip(self._metric_fn_list.keys(), results):
_dict = self._metric_fn_list[key].compute( _dict = self._metric_fn_list[key](
references=[gold], references=[gold],
predictions=[result], predictions=[result],
**self._metric_fn_kwargs[key], **self._metric_fn_kwargs[key],
......
...@@ -19,6 +19,6 @@ metric_list: ...@@ -19,6 +19,6 @@ metric_list:
- metric: acc_norm - metric: acc_norm
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
- metric: acc_mutual_info # - metric: acc_mutual_info
aggregation: mean # aggregation: mean
higher_is_better: true # higher_is_better: true
...@@ -8,7 +8,13 @@ training_split: train ...@@ -8,7 +8,13 @@ training_split: train
validation_split: validation validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:" doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{answer_choices[label]}}" doc_to_target: "{{answer_choices[label]}}"
gold_alias: "{{label}}" # this will be cast to an int. gold_alias: " {{answer_choices[label]}}" # this will be cast to an int.
generation_kwargs:
until:
- "\n\n"
- "\n"
do_sample: false
temperature: 0.0
template_aliases: "{% set answer_choices = ['no', 'yes'] %}" template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
metric_list: metric_list:
- metric: exact_match - metric: exact_match
......
group: group:
- super-glue-promptsource - super-glue-promptsource
task: "GPT-3 style" task: "rte"
dataset_path: super_glue dataset_path: super_glue
dataset_name: rte dataset_name: rte
training_split: train training_split: train
validation_split: validation validation_split: validation
use_prompt: "promptsource:GPT-3 style" use_prompt: "promptsource:GPT-3 style"
generation_kwargs:
until:
- "\n"
- "\n\n"
metric_list: metric_list:
- metric: exact_match - metric: exact_match
aggregation: mean aggregation: mean
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment