Commit 2a9da9fb authored by haileyschoelkopf, committed by Hailey Schoelkopf

add metric + agg registries

parent 460584ca
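This commit replaces the static METRIC_REGISTRY / AGGREGATION_REGISTRY dicts with decorator-based registries living in lm_eval.api.metrics. As a minimal sketch of the intended register-then-look-up flow, assuming only what the diff below introduces (register_metric, get_metric); the "toy_accuracy" metric itself is hypothetical and not part of this commit:

    from lm_eval.api.metrics import register_metric, get_metric

    @register_metric("toy_accuracy")  # hypothetical name, for illustration only
    def toy_accuracy(items):
        # assumes items is an iterable of (gold, prediction) pairs, the same
        # shape that matthews_corrcoef / f1_score unpack via zip(*items) below
        golds, preds = zip(*items)
        return sum(g == p for g, p in zip(golds, preds)) / len(golds)

    acc_fn = get_metric("toy_accuracy")  # returns the function registered above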
-from . import metrics
-
-METRIC_REGISTRY = {
-    "matthews_corrcoef": metrics.matthews_corrcoef,
-    "f1_score": metrics.f1_score,
-    "perplexity": metrics.perplexity,
-    "bleu": metrics.bleu,
-    "chrf": metrics.chrf,
-    "ter": metrics.ter,
-}
-
-AGGREGATION_REGISTRY = {
-    "mean": metrics.mean,
-    "median": metrics.median
-}
\ No newline at end of file
@@ -6,7 +6,67 @@ import sacrebleu
import sklearn.metrics
import random
+import evaluate
+
+
+AGGREGATION_REGISTRY = {}
+METRIC_REGISTRY = {}
+
+
+def register_metric(name):
+    # TODO: do we want to enforce a certain interface to registered metrics?
+    def decorate(fn):
+        assert (
+            name not in METRIC_REGISTRY
+        ), f"metric named '{name}' conflicts with existing registered metric!"
+        METRIC_REGISTRY[name] = fn
+        return fn
+    return decorate
+
+
+def get_metric(name):
+    try:
+        return METRIC_REGISTRY[name]
+    except KeyError:
+        # TODO: change this print to logging?
+        print(f"Could not find registered metric '{name}' in lm-eval, \
+searching in HF Evaluate library...")
+        try:
+            metric_object = evaluate.load(name)
+            return metric_object.compute
+        except:
+            raise Warning(
+                "{} not found in the evaluate library!".format(name),
+                "Please check https://huggingface.co/evaluate-metric",
+            )
+
+
+def register_aggregation(name):
+    def decorate(fn):
+        assert (
+            name not in AGGREGATION_REGISTRY
+        ), f"aggregation named '{name}' conflicts with existing registered aggregation!"
+        AGGREGATION_REGISTRY[name] = fn
+        return fn
+    return decorate
+
+
+def get_aggregation(name):
+    try:
+        return AGGREGATION_REGISTRY[name]
+    except KeyError:
+        raise Warning(
+            "{} not a registered aggregation metric!".format(name),
+        )
+@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)
@@ -25,10 +85,12 @@ def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))

+@register_aggregation("median")
def median(arr):
    return arr[len(arr) // 2]

+@register_metric("matthews_corrcoef")
def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
@@ -36,6 +98,7 @@ def matthews_corrcoef(items):
    return sklearn.metrics.matthews_corrcoef(golds, preds)

+@register_metric("f1_score")
def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
@@ -91,6 +154,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    return max(scores_for_ground_truths)

+@register_metric("perplexity")
def perplexity(items):
    return math.exp(-mean(items))
@@ -100,6 +164,7 @@ def weighted_mean(items):
    return sum(a) / sum(b)

+@register_metric("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))
@@ -108,6 +173,7 @@ def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)

+@register_metric("bleu")
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence to a reference sentence. It counts matching
@@ -125,6 +191,7 @@ def bleu(items):
    return sacrebleu.corpus_bleu(preds, refs).score

+@register_metric("chrf")
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
@@ -139,6 +206,7 @@ def chrf(items):
    return sacrebleu.corpus_chrf(preds, refs).score

+@register_metric("ter")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
......
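Note that get_metric falls back to the Hugging Face Evaluate library when a name is not in METRIC_REGISTRY: it calls evaluate.load(name) and returns that object's compute method. A rough usage sketch, assuming the evaluate package is installed and that "exact_match" exists on the Evaluate hub:

    from lm_eval.api.metrics import get_metric

    bleu_fn = get_metric("bleu")        # resolved from METRIC_REGISTRY
    em_fn = get_metric("exact_match")   # not registered here, loaded via evaluate.load
    out = em_fn(references=["paris"], predictions=["paris"])
    # out is the dict returned by Evaluate's compute(), e.g. {"exact_match": 1.0}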
@@ -11,9 +11,8 @@ import numpy as np
from typing import List, Union
-from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY
from lm_eval.api.instance import Instance
-from lm_eval.api.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
+from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
from lm_eval import utils
from lm_eval.filters import build_filter_ensemble
@@ -32,8 +31,8 @@ class TaskConfig(dict):
    fewshot_split: str = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
    template_aliases: str = ""
-    doc_to_text: str = None
-    doc_to_target: str = None
+    doc_to_text: str = ""
+    doc_to_target: str = ""
    # aggregation: dict = None  # TODO: remove, I think these 2 are obsolete w/ current metric_list impl.
    # higher_is_better: dict = None
@@ -111,7 +110,7 @@ class Task(abc.ABC):
        self._fewshot_docs = None
        self._instances = None
-        self._config = TaskConfig(**config) if config else {}
+        self._config = TaskConfig(**config) if config else TaskConfig()
        if not hasattr(self, "_filters"):
            self._filters = []
@@ -392,20 +391,23 @@ class ConfigurableTask(Task):
        self._higher_is_better = {}
        for (metric_name, aggregation, higher_is_better) in self._config.metric_list:
-            self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
+            self._aggregation_list[metric_name] = get_aggregation(aggregation)
            self._higher_is_better[metric_name] = higher_is_better
-            if metric_name in METRIC_REGISTRY.keys():
-                self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
-            else:
-                try:
-                    metric_object = evaluate.load(metric_name)
-                    self._metric_list[metric_name] = metric_object
-                except Exception as ex:
-                    raise Warning(
-                        "{} not found in the evaluate library!".format(metric_name),
-                        "Please check https://huggingface.co/evaluate-metric",
-                    )
+            self._metric_list[metric_name] = get_metric(metric_name)
+
+            # if metric_name in METRIC_REGISTRY.keys():
+            #     self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
+            # else:
+            #     try:
+            #         metric_object = evaluate.load(metric_name)
+            #         self._metric_list[metric_name] = metric_object
+            #     except Exception as ex:
+            #         raise Warning(
+            #             "{} not found in the evaluate library!".format(metric_name),
+            #             "Please check https://huggingface.co/evaluate-metric",
+            #             )
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
@@ -478,7 +480,7 @@ class ConfigurableTask(Task):
        result_dict = {}
        for key, result in zip(self._metric_list.keys(), results):
-            _dict = self._metric_list[key].compute(
+            _dict = self._metric_list[key](
                references=[gold],
                predictions=[result],
            )
@@ -493,7 +495,7 @@
    def higher_is_better(self):
-        return self._higher_is_better_list
+        return self._higher_is_better

class MultipleChoiceTask(Task):
@@ -659,6 +661,7 @@ def get_task_name_from_object(task_object):
        if class_ is task_object:
            return name
+    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
......
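The ConfigurableTask loop above unpacks each metric_list entry as a (metric_name, aggregation, higher_is_better) tuple and resolves the first two through get_metric / get_aggregation. A hypothetical config fragment consistent with that loop (the exact task-config schema is not shown in this diff):

    task_config = {
        "metric_list": [
            ("bleu", "mean", True),   # registered lm-eval metric; higher is better
            ("ter", "mean", False),   # error rate; lower is better
        ],
        # ... other TaskConfig fields such as doc_to_text / doc_to_target ...
    }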
@@ -5,14 +5,4 @@ from . import gpt3
from . import textsynth
from . import dummy
-# MODEL_REGISTRY = {}
-# MODEL_REGISTRY = {
-#     "hf-causal": gpt2.HFLM,
-#     "openai": gpt3.GPT3LM,
-#     "textsynth": textsynth.TextSynthLM,
-#     "dummy": dummy.DummyLM,
-# }
-
-# def get_model(model_name):
-#     return MODEL_REGISTRY[model_name]
+# TODO: implement __all__