Unverified Commit 408115ea authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #935 from EleutherAI/fix-default-metric-call

parents cf617ab1 1428ad57
......@@ -5,6 +5,7 @@ import numpy as np
import sacrebleu
import sklearn.metrics
import random
import evaluate
from lm_eval.api.registry import register_metric, register_aggregation
......@@ -135,6 +136,19 @@ def acc_mutual_info_fn(items): # This is a passthrough function
return items
exact_match = evaluate.load("exact_match")
@register_metric(
metric="exact_match",
higher_is_better=True,
output_type="generate_until",
aggregation="mean",
)
def exact_match_fn(**kwargs):
return exact_match.compute(**kwargs)
@register_metric(
metric="perplexity",
higher_is_better=False,
......
......@@ -68,10 +68,10 @@ def register_group(name):
return decorate
AGGREGATION_REGISTRY = {}
DEFAULT_AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {}
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {}
DEFAULT_METRIC_REGISTRY = {
......@@ -95,8 +95,7 @@ def register_metric(**args):
for key, registry in [
("metric", METRIC_REGISTRY),
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
# ("output_type", OUTPUT_TYPE_REGISTRY),
("aggregation", DEFAULT_AGGREGATION_REGISTRY),
("aggregation", METRIC_AGGREGATION_REGISTRY),
]:
if key in args:
......@@ -158,12 +157,13 @@ def get_aggregation(name):
)
def get_default_aggregation(metric_name):
def get_metric_aggregation(name):
try:
return DEFAULT_AGGREGATION_REGISTRY[metric_name]
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(
f"No default aggregation metric for metric '{metric_name}'!"
"{} metric is not assigned a default aggregation!".format(name),
)
......
......@@ -33,7 +33,7 @@ from lm_eval.api.metrics import (
from lm_eval.api.registry import (
get_metric,
get_aggregation,
get_default_aggregation,
get_metric_aggregation,
is_higher_better,
DEFAULT_METRIC_REGISTRY,
OUTPUT_TYPE_REGISTRY,
......@@ -538,12 +538,14 @@ class ConfigurableTask(Task):
self._aggregation_list = {}
self._higher_is_better = {}
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
if self.config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ?
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = get_default_aggregation(
self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = get_metric_aggregation(
metric_name
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
......@@ -586,7 +588,7 @@ class ConfigurableTask(Task):
]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = get_default_aggregation(metric_name)
metric_agg = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
......@@ -687,7 +689,10 @@ class ConfigurableTask(Task):
for choice in check_choices:
choice_has_whitespace = True if choice[0].isspace() else False
delimiter_has_whitespace = (
True if self.config.target_delimiter[-1].isspace() else False
True
if self.config.target_delimiter.rstrip()
== self.config.target_delimiter
else False
)
if delimiter_has_whitespace and choice_has_whitespace:
......
......@@ -663,8 +663,16 @@ class HFLM(LM):
chunks = utils.chunks(
re_ord.get_reordered(),
n=self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0,
fn=self._batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs else None,
n=self.batch_size
if self.batch_size != "auto"
else override_bs
if override_bs is not None
else 0,
fn=self._batch_scheduler
if self.batch_size == "auto"
and n_reordered_requests > 0
and not override_bs
else None,
)
for chunk in tqdm(chunks, disable=(disable_tqdm or (self.rank != 0))):
......@@ -840,8 +848,14 @@ class HFLM(LM):
for key, re_ord in re_ords.items():
chunks = utils.chunks(
re_ord.get_reordered(),
n=self.batch_size if self.batch_size != "auto" else adaptive_batch_size if adaptive_batch_size is not None else 0,
fn=self._batch_scheduler if self.batch_size == "auto" and not adaptive_batch_size else None,
n=self.batch_size
if self.batch_size != "auto"
else adaptive_batch_size
if adaptive_batch_size is not None
else 0,
fn=self._batch_scheduler
if self.batch_size == "auto" and not adaptive_batch_size
else None,
)
for chunk in tqdm(chunks, disable=self.rank != 0):
contexts, all_gen_kwargs = zip(*chunk)
......
......@@ -15,7 +15,8 @@ from lm_eval.api.registry import (
import logging
eval_logger = logging.getLogger('lm-eval')
eval_logger = logging.getLogger("lm-eval")
def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type(
......
......@@ -9,4 +9,4 @@ task:
- wsc
- ai2_arc
- blimp
- hendrycksTest*
- mmlu
......@@ -8,7 +8,8 @@ training_split: train
validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
doc_to_target: label
doc_to_choice: ['no', 'yes']
doc_to_choice: [' no', ' yes']
target_delimiter: ""
generation_kwargs:
until:
- "\n\n"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment