Commit aaf64aab authored by lintangsutawika

re-added support for aggregation

parent 439dca55
......@@ -8,20 +8,23 @@ import numpy as np
import sacrebleu
import sklearn.metrics
from lm_eval.api.registry import register_metric
from lm_eval.api.registry import register_metric, register_aggregation
eval_logger = logging.getLogger("lm-eval")
@register_aggregation("mean")
def mean(arr):
return sum(arr) / len(arr)
@register_aggregation("median")
def median(arr):
return arr[len(arr) // 2]
@register_aggregation("weighted_mean")
def weighted_mean(items):
a, b = zip(*items)
return sum(a) / sum(b)
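
A small standalone sketch of how the re-added aggregations are meant to be used: per-document metric scores collected by a task are reduced to a single task-level number. The functions below mirror the ones registered above (redefined here only so the snippet runs on its own); the sample scores are made up for illustration.

```python
# Mirrors the aggregations re-registered above; standalone for illustration.
def mean(arr):
    return sum(arr) / len(arr)

def weighted_mean(items):
    # items are (value, weight) pairs, e.g. (loglikelihood, token count)
    a, b = zip(*items)
    return sum(a) / sum(b)

per_doc_acc = [1, 0, 1, 1]              # hypothetical per-document accuracy scores
per_doc_ll = [(-12.3, 5), (-20.1, 9)]   # hypothetical (loglikelihood, length) pairs

print(mean(per_doc_acc))          # 0.75
print(weighted_mean(per_doc_ll))  # ~ -2.31, length-weighted average loglikelihood
```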
......@@ -161,6 +164,7 @@ def acc_mutual_info_fn(items):
exact_match = evaluate.load("exact_match")
@register_metric(
metric="exact_match",
higher_is_better=True,
......
import os
import logging
import evaluate
import collections
from functools import partial
from lm_eval.api.model import LM
......@@ -9,21 +10,6 @@ eval_logger = logging.getLogger("lm-eval")
MODEL_REGISTRY = {}
class HFEvaluateAdaptor:
def __init__(self, name, **kwargs):
self.name = name
metric_object = evaluate.load(name)
self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)
def __call__(self, items):
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
return self.hf_evaluate_fn(
references=refs,
predictions=preds
)[self.name]
def register_model(*names):
# either pass a list or a single alias.
......@@ -87,8 +73,8 @@ def register_group(name):
return decorate
METRIC_FUNCTION_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {}
METRIC_REGISTRY = collections.defaultdict(dict)
AGGREGATION_REGISTRY = collections.defaultdict(dict)
DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [],
......@@ -102,6 +88,7 @@ def register_metric(
metric=None,
higher_is_better=None,
output_type=None,
aggregation=None,
):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
......@@ -112,10 +99,13 @@ def register_metric(
metric_list = metric
for _metric in metric_list:
METRIC_FUNCTION_REGISTRY[_metric] = fn
METRIC_REGISTRY[_metric]["function"] = fn
if aggregation is not None:
METRIC_REGISTRY[_metric]["aggregation"] = aggregation
if higher_is_better is not None:
HIGHER_IS_BETTER_REGISTRY[_metric] = higher_is_better
METRIC_REGISTRY[_metric]["higher_is_better"] = higher_is_better
if output_type is not None:
if type(output_type) == str:
......@@ -131,18 +121,33 @@ def register_metric(
return decorate
def get_metric(name, hf_evaluate_metric=False, **kwargs):
def get_metric(name):
if not hf_evaluate_metric:
if name in METRIC_FUNCTION_REGISTRY:
return METRIC_FUNCTION_REGISTRY[name]
else:
eval_logger.warning(
f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
)
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
else:
eval_logger.error(f"Could not find registered metric '{name}' in lm-eval")
def get_evaluate(name, **kwargs):
try:
# from lm_eval.metrics import HFEvaluateAdaptor
class HFEvaluateAdaptor:
def __init__(self, name, **kwargs):
self.name = name
metric_object = evaluate.load(name)
self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)
def __call__(self, items):
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
return self.hf_evaluate_fn(references=refs, predictions=preds)[
self.name
]
return HFEvaluateAdaptor(name, **kwargs)
except Exception:
eval_logger.error(
......@@ -150,10 +155,22 @@ def get_metric(name, hf_evaluate_metric=False, **kwargs):
)
def is_higher_better(metric_name):
def register_aggregation(name):
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
), f"aggregation named '{name}' conflicts with existing registered aggregation!"
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
def get_aggregation(name):
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(
f"higher_is_better not specified for metric '{metric_name}'!"
"{} not a registered aggregation metric!".format(name),
)
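
Taken together, the registry changes above mean `register_metric` now writes a per-metric entry dict (`function`, optional `aggregation`, optional `higher_is_better`) into `METRIC_REGISTRY`, `get_metric` returns that dict, and aggregation names resolve through `get_aggregation`. A hedged usage sketch, assuming the lm_eval version introduced by this commit; the `demo_acc` metric and its calling convention are hypothetical and only illustrate the registry entry shape:

```python
import lm_eval.api.metrics  # noqa: F401  # importing registers the built-in metrics/aggregations
from lm_eval.api.registry import register_metric, get_metric, get_aggregation

@register_metric(
    metric="demo_acc",       # hypothetical metric name, not part of lm-eval
    higher_is_better=True,
    aggregation="mean",      # name of a registered aggregation
)
def demo_acc_fn(items):
    # in this sketch, items are (gold, pred) pairs mapped to per-document 0/1 scores
    return [int(gold == pred) for gold, pred in items]

entry = get_metric("demo_acc")
scores = entry["function"]([("a", "a"), ("a", "b")])  # [1, 0]
aggregate = get_aggregation(entry["aggregation"])     # the registered mean()
print(aggregate(scores), entry["higher_is_better"])   # 0.5 True
```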
......@@ -32,7 +32,9 @@ from lm_eval.api.metrics import (
)
from lm_eval.api.registry import (
get_metric,
is_higher_better,
get_evaluate,
get_aggregation,
METRIC_REGISTRY,
DEFAULT_METRIC_REGISTRY,
)
......@@ -410,7 +412,7 @@ class Task(abc.ABC):
pass
@abc.abstractmethod
def compute_metric(self):
def aggregation(self):
"""
:returns: {str: [metric_score] -> float}
A dictionary where keys are the names of submetrics and values are
......@@ -553,6 +555,7 @@ class ConfigurableTask(Task):
self._metric_fn_list = {}
self._metric_fn_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
if self.config.metric_list is None:
......@@ -561,12 +564,14 @@ class ConfigurableTask(Task):
for metric_name in _metric_list:
metric = get_metric(metric_name)
self._metric_fn_list[metric_name] = metric
self._metric_fn_list[metric_name] = metric["function"]
self._metric_fn_kwargs[metric_name] = {}
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self._aggregation_list[metric_name] = metric["aggregation"]
self._higher_is_better[metric_name] = metric["higher_is_better"]
else:
for metric_config in self.config.metric_list:
assert "metric" in metric_config
from_registry = False
metric_name = metric_config["metric"]
kwargs = {
key: metric_config[key]
......@@ -574,25 +579,47 @@ class ConfigurableTask(Task):
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
hf_evaluate_metric = (
use_hf_evaluate = (
"hf_evaluate" in metric_config
and metric_config["hf_evaluate"] is True
)
# if self.config.process_results is not None:
# self._metric_fn_list[metric_name] = None
# self._metric_fn_kwargs[metric_name] = {}
if callable(metric_name):
metric_fn = metric_name.__call__
metric_name = metric_name.__name__
else:
metric_fn = get_metric(
metric_name, hf_evaluate_metric, **kwargs
)
assert type(metric_name) == str
if use_hf_evaluate:
metric_fn = get_evaluate(metric_name, **kwargs)
elif metric_name in METRIC_REGISTRY:
from_registry = True
metric = get_metric(metric_name)
metric_fn = metric["function"]
self._metric_fn_kwargs[metric_name] = kwargs
self._metric_fn_list[metric_name] = metric_fn
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
if isinstance(agg_name, str):
self._aggregation_list[metric_name] = get_aggregation(agg_name)
elif callable(agg_name): # noqa: E721
self._aggregation_list[metric_name] = agg_name
else:
if from_registry:
if "aggregation" in metric:
self._aggregation_list[metric_name] = metric["aggregation"]
else:
self._aggregation_list[metric_name] = metric_fn
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
]
else:
if from_registry:
self._higher_is_better[metric_name] = metric["higher_is_better"]
self.download(self.config.dataset_kwargs)
self._training_docs = None
self._fewshot_docs = None
......@@ -1157,8 +1184,8 @@ class ConfigurableTask(Task):
return result_dict
def compute_metric(self):
return self._metric_fn_list
def aggregation(self):
return self._aggregation_list
def higher_is_better(self):
return self._higher_is_better
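
For context on how the branches above get exercised, here is a hedged sketch of a task's `metric_list` as ConfigurableTask receives it (shown as the equivalent Python structure; the concrete entries and extra kwargs are illustrative, not taken from this commit):

```python
# Hypothetical metric_list: field names follow the resolution logic above.
metric_list = [
    {
        # a metric registered in lm-eval: its function, and (when present)
        # default aggregation and higher_is_better, come from METRIC_REGISTRY
        "metric": "acc",
    },
    {
        # an HF Evaluate metric: resolved through get_evaluate(), so the
        # aggregation and direction are given explicitly in the config;
        # any extra keys (e.g. ignore_case) are forwarded as metric kwargs
        "metric": "exact_match",
        "hf_evaluate": True,
        "aggregation": "mean",
        "higher_is_better": True,
        "ignore_case": True,
    },
]
```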
......@@ -1204,7 +1231,7 @@ class MultipleChoiceTask(Task):
"acc_norm": True,
}
def compute_metric(self) -> dict:
def aggregation(self) -> dict:
return {
"acc": mean,
"acc_norm": mean,
......@@ -1265,7 +1292,7 @@ class PerplexityTask(Task):
"bits_per_byte": (loglikelihood, bytes_),
}
def compute_metric(self) -> dict:
def aggregation(self) -> dict:
return {
"word_perplexity": weighted_perplexity,
"byte_perplexity": weighted_perplexity,
......