Commit aaf64aab authored by lintangsutawika

readded support for aggregation

parent 439dca55
@@ -8,20 +8,23 @@ import numpy as np
 import sacrebleu
 import sklearn.metrics

-from lm_eval.api.registry import register_metric
+from lm_eval.api.registry import register_metric, register_aggregation

 eval_logger = logging.getLogger("lm-eval")


+@register_aggregation("mean")
 def mean(arr):
     return sum(arr) / len(arr)


+@register_aggregation("median")
 def median(arr):
     return arr[len(arr) // 2]


+@register_aggregation("weighted_mean")
 def weighted_mean(items):
     a, b = zip(*items)
     return sum(a) / sum(b)
@@ -161,6 +164,7 @@ def acc_mutual_info_fn(items):

 exact_match = evaluate.load("exact_match")


 @register_metric(
     metric="exact_match",
     higher_is_better=True,
...
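For orientation, the readded hooks would be exercised roughly as in the sketch below. This is illustrative only: the `count` aggregation and `my_exact_match` metric are hypothetical names, and the per-item interface of the metric function is an assumption, not something shown in this diff.

from lm_eval.api.metrics import mean
from lm_eval.api.registry import register_aggregation, register_metric


# Hypothetical aggregation, stored under the name "count" in AGGREGATION_REGISTRY.
@register_aggregation("count")
def count(arr):
    return len(arr)


# Hypothetical metric; the aggregation= argument added in this commit stores the
# callable alongside the function and higher_is_better flag in METRIC_REGISTRY.
@register_metric(
    metric="my_exact_match",
    higher_is_better=True,
    aggregation=mean,
)
def my_exact_match(items):
    # assumed per-item scoring, later reduced by the registered "mean" aggregation
    golds, preds = zip(*items)
    return [float(g == p) for g, p in zip(golds, preds)]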
 import os
 import logging
 import evaluate
+import collections

 from functools import partial

 from lm_eval.api.model import LM
@@ -9,21 +10,6 @@ eval_logger = logging.getLogger("lm-eval")

 MODEL_REGISTRY = {}


-class HFEvaluateAdaptor:
-    def __init__(self, name, **kwargs):
-        self.name = name
-        metric_object = evaluate.load(name)
-        self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)
-
-    def __call__(self, items):
-        refs = list(zip(*items))[0]
-        preds = list(zip(*items))[1]
-        return self.hf_evaluate_fn(
-            references=refs,
-            predictions=preds
-        )[self.name]
-
-
 def register_model(*names):
     # either pass a list or a single alias.
@@ -87,8 +73,8 @@ def register_group(name):
     return decorate


-METRIC_FUNCTION_REGISTRY = {}
-HIGHER_IS_BETTER_REGISTRY = {}
+METRIC_REGISTRY = collections.defaultdict(dict)
+AGGREGATION_REGISTRY = collections.defaultdict(dict)

 DEFAULT_METRIC_REGISTRY = {
     "loglikelihood": [],
@@ -102,6 +88,7 @@ def register_metric(
     metric=None,
     higher_is_better=None,
     output_type=None,
+    aggregation=None,
 ):
     # TODO: do we want to enforce a certain interface to registered metrics?
     def decorate(fn):
@@ -112,10 +99,13 @@ def register_metric(
             metric_list = metric

         for _metric in metric_list:
-            METRIC_FUNCTION_REGISTRY[_metric] = fn
+            METRIC_REGISTRY[_metric]["function"] = fn
+
+            if aggregation is not None:
+                METRIC_REGISTRY[_metric]["aggregation"] = aggregation

             if higher_is_better is not None:
-                HIGHER_IS_BETTER_REGISTRY[_metric] = higher_is_better
+                METRIC_REGISTRY[_metric]["higher_is_better"] = higher_is_better

             if output_type is not None:
                 if type(output_type) == str:
@@ -131,18 +121,33 @@ def register_metric(
     return decorate
-def get_metric(name, hf_evaluate_metric=False, **kwargs):
-    if not hf_evaluate_metric:
-        if name in METRIC_FUNCTION_REGISTRY:
-            return METRIC_FUNCTION_REGISTRY[name]
-        else:
-            eval_logger.warning(
-                f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
-            )
+def get_metric(name):
+    if name in METRIC_REGISTRY:
+        return METRIC_REGISTRY[name]
+    else:
+        eval_logger.error(f"Could not find registered metric '{name}' in lm-eval")
+
+
+def get_evaluate(name, **kwargs):
     try:
+        # from lm_eval.metrics import HFEvaluateAdaptor
+        class HFEvaluateAdaptor:
+            def __init__(self, name, **kwargs):
+                self.name = name
+                metric_object = evaluate.load(name)
+                self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)
+
+            def __call__(self, items):
+                refs = list(zip(*items))[0]
+                preds = list(zip(*items))[1]
+                return self.hf_evaluate_fn(references=refs, predictions=preds)[
+                    self.name
+                ]
+
         return HFEvaluateAdaptor(name, **kwargs)
     except Exception:
         eval_logger.error(
@@ -150,10 +155,22 @@ def get_metric(name, hf_evaluate_metric=False, **kwargs):
         )
-def is_higher_better(metric_name):
+def register_aggregation(name):
+    def decorate(fn):
+        assert (
+            name not in AGGREGATION_REGISTRY
+        ), f"aggregation named '{name}' conflicts with existing registered aggregation!"
+
+        AGGREGATION_REGISTRY[name] = fn
+        return fn
+
+    return decorate
+
+
+def get_aggregation(name):
     try:
-        return HIGHER_IS_BETTER_REGISTRY[metric_name]
+        return AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(
-            f"higher_is_better not specified for metric '{metric_name}'!"
+            "{} not a registered aggregation metric!".format(name),
         )
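With the registry reshaped this way, lookups would go roughly as in the sketch below; the names assume the registrations visible above, and the commented HF Evaluate call (`"bleu"`, `smooth=True`) is hypothetical.

from lm_eval.api.registry import get_aggregation, get_evaluate, get_metric

entry = get_metric("exact_match")   # per-metric dict: "function", "higher_is_better", possibly "aggregation"
exact_match_fn = entry["function"]
mean_fn = get_aggregation("mean")   # looked up in AGGREGATION_REGISTRY
# hf_bleu = get_evaluate("bleu", smooth=True)  # hypothetical: wraps evaluate.load("bleu") in HFEvaluateAdaptor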
@@ -32,7 +32,9 @@ from lm_eval.api.metrics import (
 )
 from lm_eval.api.registry import (
     get_metric,
-    is_higher_better,
+    get_evaluate,
+    get_aggregation,
+    METRIC_REGISTRY,
     DEFAULT_METRIC_REGISTRY,
 )
@@ -410,7 +412,7 @@ class Task(abc.ABC):
         pass

     @abc.abstractmethod
-    def compute_metric(self):
+    def aggregation(self):
         """
         :returns: {str: [metric_score] -> float}
         A dictionary where keys are the names of submetrics and values are
@@ -553,6 +555,7 @@ class ConfigurableTask(Task):

         self._metric_fn_list = {}
         self._metric_fn_kwargs = {}
+        self._aggregation_list = {}
         self._higher_is_better = {}

         if self.config.metric_list is None:
@@ -561,12 +564,14 @@ class ConfigurableTask(Task):
             for metric_name in _metric_list:
                 metric = get_metric(metric_name)
-                self._metric_fn_list[metric_name] = metric
+                self._metric_fn_list[metric_name] = metric["function"]
                 self._metric_fn_kwargs[metric_name] = {}
-                self._higher_is_better[metric_name] = is_higher_better(metric_name)
+                self._aggregation_list = metric["aggregation"]
+                self._higher_is_better[metric_name] = metric["is_higher_better"]
         else:
             for metric_config in self.config.metric_list:
                 assert "metric" in metric_config
+                from_registry = False
                 metric_name = metric_config["metric"]
                 kwargs = {
                     key: metric_config[key]
@@ -574,25 +579,47 @@ class ConfigurableTask(Task):
                     if key
                     not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
                 }
-                hf_evaluate_metric = (
+                use_hf_evaluate = (
                     "hf_evaluate" in metric_config
                     and metric_config["hf_evaluate"] is True
                 )

-                # if self.config.process_results is not None:
-                #     self._metric_fn_list[metric_name] = None
-                #     self._metric_fn_kwargs[metric_name] = {}

                 if callable(metric_name):
                     metric_fn = metric_name.__call__
                     metric_name = metric_name.__name__
                 else:
-                    metric_fn = get_metric(
-                        metric_name, hf_evaluate_metric, **kwargs
-                    )
+                    assert type(metric_name) == str
+                    if use_hf_evaluate:
+                        metric_fn = get_evaluate(metric_name, **kwargs)
+                    elif metric_name in METRIC_REGISTRY:
+                        from_registry = True
+                        metric = get_metric(metric_name, **kwargs)
+                        metric_fn = metric["function"]

                 self._metric_fn_kwargs[metric_name] = kwargs
                 self._metric_fn_list[metric_name] = metric_fn

+                if "aggregation" in metric_config:
+                    agg_name = metric_config["aggregation"]
+                    if isinstance(agg_name, str):
+                        self._aggregation_list[metric_name] = get_aggregation(agg_name)
+                    elif callable(agg_name):  # noqa: E721
+                        self._aggregation_list[metric_name] = agg_name
+                else:
+                    if from_registry:
+                        if "aggregation" in metric:
+                            self._aggregation_list[metric_name] = metric["aggregation"]
+                    else:
+                        self._aggregation_list[metric_name] = metric_fn
+
+                if "higher_is_better" in metric_config:
+                    self._higher_is_better[metric_name] = metric_config[
+                        "higher_is_better"
+                    ]
+                else:
+                    if from_registry:
+                        self._higher_is_better[metric_name] = metric["higher_is_better"]

         self.download(self.config.dataset_kwargs)
         self._training_docs = None
         self._fewshot_docs = None
@@ -1157,8 +1184,8 @@ class ConfigurableTask(Task):

         return result_dict

-    def compute_metric(self):
-        return self._metric_fn_list
+    def aggregation(self):
+        return self._aggregation_list

     def higher_is_better(self):
         return self._higher_is_better
@@ -1204,7 +1231,7 @@ class MultipleChoiceTask(Task):
             "acc_norm": True,
         }

-    def compute_metric(self) -> dict:
+    def aggregation(self) -> dict:
         return {
             "acc": mean,
             "acc_norm": mean,
@@ -1265,7 +1292,7 @@ class PerplexityTask(Task):
             "bits_per_byte": (loglikelihood, bytes_),
         }

-    def compute_metric(self) -> dict:
+    def aggregation(self) -> dict:
         return {
             "word_perplexity": weighted_perplexity,
             "byte_perplexity": weighted_perplexity,
...
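At the task level, a config's metric_list entry can now carry its own aggregation again. Below is a minimal sketch of the structure ConfigurableTask iterates over above, with hypothetical values and only the keys the diff actually inspects.

# Hypothetical metric_list; keys other than "metric", "aggregation",
# "higher_is_better", and "hf_evaluate" are passed to the metric as kwargs.
metric_list = [
    {
        "metric": "exact_match",   # found in METRIC_REGISTRY, so metric["function"] is used
        "aggregation": "mean",     # a string is resolved through get_aggregation("mean")
        "higher_is_better": True,
    },
    {
        "metric": "bleu",          # hypothetical HF Evaluate metric name
        "hf_evaluate": True,       # routed through get_evaluate(...) instead of the registry
    },
]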