Commit 2a9da9fb authored by haileyschoelkopf, committed by Hailey Schoelkopf

add metric + agg registries

parent 460584ca
from . import metrics
METRIC_REGISTRY = {
"matthews_corrcoef": metrics.matthews_corrcoef,
"f1_score": metrics.f1_score,
"perplexity": metrics.perplexity,
"bleu": metrics.bleu,
"chrf": metrics.chrf,
"ter": metrics.ter,
}
AGGREGATION_REGISTRY = {
"mean": metrics.mean,
"median": metrics.median
}
\ No newline at end of file
@@ -6,7 +6,67 @@ import sacrebleu
import sklearn.metrics
import random
import evaluate
AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {}
def register_metric(name):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
assert (
name not in METRIC_REGISTRY
), f"metric named '{name}' conflicts with existing registered metric!"
METRIC_REGISTRY[name] = fn
return fn
return decorate
def get_metric(name):
try:
return METRIC_REGISTRY[name]
except KeyError:
# TODO: change this print to logging?
        print(
            f"Could not find registered metric '{name}' in lm-eval, "
            "searching in HF Evaluate library..."
        )
        try:
            metric_object = evaluate.load(name)
            return metric_object.compute
        except Exception as exc:
            raise ValueError(
                f"{name} not found in the evaluate library! "
                "Please check https://huggingface.co/evaluate-metric"
            ) from exc
def register_aggregation(name):
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
), f"aggregation named '{name}' conflicts with existing registered aggregation!"
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
def get_aggregation(name):
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
        raise ValueError(
            f"'{name}' is not a registered aggregation!"
        )
@register_aggregation("mean")
def mean(arr):
return sum(arr) / len(arr)
@@ -25,10 +85,12 @@ def mean_stderr(arr):
return sample_stddev(arr) / math.sqrt(len(arr))
@register_aggregation("median")
def median(arr):
    # NOTE: returns the middle element of `arr` as given; this is the true
    # median only if `arr` is already sorted (and has odd length)
    return arr[len(arr) // 2]
@register_metric("matthews_corrcoef")
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
@@ -36,6 +98,7 @@ def matthews_corrcoef(items):
return sklearn.metrics.matthews_corrcoef(golds, preds)
@register_metric("f1_score")
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
@@ -91,6 +154,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths)
@register_metric("perplexity")
def perplexity(items):
return math.exp(-mean(items))
@@ -100,6 +164,7 @@ def weighted_mean(items):
return sum(a) / sum(b)
@register_metric("weighted_perplexity")
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
@@ -108,6 +173,7 @@ def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
@register_metric("bleu")
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
@@ -125,6 +191,7 @@ def bleu(items):
return sacrebleu.corpus_bleu(preds, refs).score
@register_metric("chrf")
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
@@ -139,6 +206,7 @@ def chrf(items):
return sacrebleu.corpus_chrf(preds, refs).score
@register_metric("ter")
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
......
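For illustration, a minimal usage sketch of the registry API added in the hunks above. The import path matches the lm_eval.api.metrics import used below; the "exact_match" metric name and the (gold, pred) item convention are hypothetical examples, not part of this commit:

from lm_eval.api.metrics import register_metric, get_metric, get_aggregation


@register_metric("exact_match")  # hypothetical metric name, for illustration only
def exact_match(items):
    # assumes `items` is an iterable of (gold, pred) pairs, mirroring the
    # convention used by matthews_corrcoef / f1_score above
    golds, preds = zip(*items)
    return sum(g == p for g, p in zip(golds, preds)) / len(golds)


# lookup by name: a locally registered function is returned directly;
# unknown names fall back to evaluate.load(name).compute
metric_fn = get_metric("exact_match")
agg_fn = get_aggregation("mean")

print(metric_fn([("a", "a"), ("a", "b")]))  # 0.5
print(agg_fn([1.0, 0.0]))                   # 0.5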
@@ -11,9 +11,8 @@ import numpy as np
from typing import List, Union
from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
from lm_eval import utils
from lm_eval.filters import build_filter_ensemble
@@ -32,8 +31,8 @@ class TaskConfig(dict):
    fewshot_split: str = None  # TODO: assert that this is not None if num_fewshot > 0 (?); assert whether this is the same split as the one being evaluated (?)
template_aliases: str = ""
doc_to_text: str = None
doc_to_target: str = None
doc_to_text: str = ""
doc_to_target: str = ""
# aggregation: dict = None # TODO: remove, I think these 2 are obsolete w/ current metric_list impl.
# higher_is_better: dict = None
@@ -111,7 +110,7 @@ class Task(abc.ABC):
self._fewshot_docs = None
self._instances = None
self._config = TaskConfig(**config) if config else {}
self._config = TaskConfig(**config) if config else TaskConfig()
if not hasattr(self, "_filters"):
self._filters = []
@@ -392,20 +391,23 @@ class ConfigurableTask(Task):
self._higher_is_better = {}
for (metric_name, aggregation, higher_is_better) in self._config.metric_list:
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
self._aggregation_list[metric_name] = get_aggregation(aggregation)
self._higher_is_better[metric_name] = higher_is_better
if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
else:
try:
metric_object = evaluate.load(metric_name)
self._metric_list[metric_name] = metric_object
except Exception as ex:
raise Warning(
"{} not found in the evaluate library!".format(metric_name),
"Please check https://huggingface.co/evaluate-metric",
)
self._metric_list[metric_name] = get_metric(metric_name)
# if metric_name in METRIC_REGISTRY.keys():
# self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
# else:
# try:
# metric_object = evaluate.load(metric_name)
# self._metric_list[metric_name] = metric_object
# except Exception as ex:
# raise Warning(
# "{} not found in the evaluate library!".format(metric_name),
# "Please check https://huggingface.co/evaluate-metric",
# )
self.download(data_dir, cache_dir, download_mode)
self._training_docs = None
@@ -478,7 +480,7 @@
result_dict = {}
for key, result in zip(self._metric_list.keys(), results):
_dict = self._metric_list[key].compute(
_dict = self._metric_list[key](
references=[gold],
predictions=[result],
)
@@ -493,7 +495,7 @@
def higher_is_better(self):
return self._higher_is_better_list
return self._higher_is_better
class MultipleChoiceTask(Task):
@@ -659,6 +661,7 @@ def get_task_name_from_object(task_object):
if class_ is task_object:
return name
# TODO: scrap this
# this gives a mechanism for non-registered tasks to have a custom name anyway when reporting
return (
task_object.EVAL_HARNESS_NAME
......
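A short sketch of the lookup path the ConfigurableTask changes above now rely on: get_metric returns a plain callable in both cases (a locally registered function, or the .compute method of an HF Evaluate metric for unregistered names), which is why process_results no longer calls .compute itself. The "accuracy" name is assumed to be available through the evaluate package and is used here only as an example:

from lm_eval.api.metrics import get_metric, get_aggregation

# "accuracy" is not in METRIC_REGISTRY, so get_metric is expected to fall back
# to evaluate.load("accuracy").compute (assuming the `evaluate` package and
# this metric are available in the environment)
metric_fn = get_metric("accuracy")
agg_fn = get_aggregation("mean")

gold, result = 1, 1
per_doc = metric_fn(references=[gold], predictions=[result])  # e.g. {"accuracy": 1.0}
print(agg_fn([per_doc["accuracy"], 0.0]))                     # 0.5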
@@ -5,14 +5,4 @@ from . import gpt3
from . import textsynth
from . import dummy
# MODEL_REGISTRY = {}
# MODEL_REGISTRY = {
# "hf-causal": gpt2.HFLM,
# "openai": gpt3.GPT3LM,
# "textsynth": textsynth.TextSynthLM,
# "dummy": dummy.DummyLM,
# }
# def get_model(model_name):
# return MODEL_REGISTRY[model_name]
# TODO: implement __all__
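The commented-out MODEL_REGISTRY above suggests the same registry pattern may later be applied to models. A hypothetical sketch along those lines, not part of this commit, mirroring the metric/aggregation decorators:

# Hypothetical sketch (not in this commit) of a decorator-based model registry.
MODEL_REGISTRY = {}


def register_model(name):
    def decorate(cls):
        assert (
            name not in MODEL_REGISTRY
        ), f"model named '{name}' conflicts with existing registered model!"
        MODEL_REGISTRY[name] = cls
        return cls

    return decorate


def get_model(model_name):
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"'{model_name}' is not a registered model!")


# e.g. (hypothetical): @register_model("hf-causal") above the HFLM class in gpt2.py,
# so that get_model("hf-causal") resolves the class by name.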