"mmdet3d/vscode:/vscode.git/clone" did not exist on "864ed34f565ea5c066778c3c1aa708903ec22be4"
Commit b3591562 authored by lintangsutawika

metrics are now in a special folder so that registry can work better

parent 48344fcb
import math
from collections.abc import Iterable
import numpy as np
import sacrebleu
import sklearn.metrics
import random
import evaluate
AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {
"acc": None,
"acc_norm": None,
"acc_mutual_info": None,
"word_perplexity": None,
"byte_perplexity": None,
}
HIGHER_IS_BETTER_REGISTRY = {
"matthews_corrcoef": True,
"f1_score": True,
"perplexity": False,
"bleu": True,
"chrf": True,
"ter": False,
"acc": True,
"acc_norm": True,
"acc_mutual_info": True,
"word_perplexity": False,
"byte_perplexity": False,
"bits_per_byte": False,
}
def register_metric(name):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
assert (
name not in METRIC_REGISTRY
), f"metric named '{name}' conflicts with existing registered metric!"
METRIC_REGISTRY[name] = fn
return fn
return decorate
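For illustration, a minimal usage sketch of the decorator above; the metric name and body here are hypothetical and not part of this commit:
@register_metric("my_exact_match")  # hypothetical name, for illustration only
def my_exact_match(items):
    # items is assumed to be an iterable of (gold, pred) pairs
    golds, preds = zip(*items)
    return sum(g == p for g, p in zip(golds, preds)) / len(golds)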
def get_metric(name):
try:
return METRIC_REGISTRY[name]
except KeyError:
# TODO: change this print to logging?
print(
f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
)
try:
metric_object = evaluate.load(name)
return metric_object.compute
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(name),
"Please check https://huggingface.co/evaluate-metric",
)
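When a name is missing from METRIC_REGISTRY, get_metric falls back to the Hugging Face Evaluate hub. A quick usage sketch, assuming the hub's "accuracy" metric is reachable:
# "accuracy" is not registered locally, so this loads evaluate's metric
compute = get_metric("accuracy")
print(compute(predictions=[0, 1, 1], references=[0, 1, 0]))  # {'accuracy': 0.666...}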
def register_aggregation(name):
# TODO: should we enforce a specific interface to aggregation metrics?
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
), f"aggregation named '{name}' conflicts with existing registered aggregation!"
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
def get_aggregation(name):
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
raise Warning(
"{} not a registered aggregation metric!".format(name),
)
@register_aggregation("mean")
def mean(arr):
return sum(arr) / len(arr)
def pop_stddev(arr):
mu = sum(arr) / len(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
def sample_stddev(arr):
mu = sum(arr) / len(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
@@ -110,48 +16,6 @@ def mean_stderr(arr):
return sample_stddev(arr) / math.sqrt(len(arr))
@register_aggregation("median")
def median(arr):
return arr[len(arr) // 2]
@register_metric("matthews_corrcoef")
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
return sklearn.metrics.matthews_corrcoef(golds, preds)
@register_metric("f1_score")
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = sklearn.metrics.f1_score(golds, preds)
return np.max(fscore)
def acc_all(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
preds = list(zip(*items))[0]
docs = list(zip(*items))[1]
for doc, pred in zip(docs, preds):
paragraph_id = doc["idx"]["paragraph"]
question_id = doc["idx"]["question"]
if (paragraph_id, question_id) not in question_scoring_dict:
question_scoring_dict[(paragraph_id, question_id)] = []
gold_label = doc["label"] == 1
question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
return acc
def acc_all_stderr(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
@@ -179,113 +43,6 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths)
@register_metric("perplexity")
@register_aggregation("perplexity")
def perplexity(items):
return math.exp(-mean(items))
def weighted_mean(items):
a, b = zip(*items)
return sum(a) / sum(b)
@register_metric("weighted_perplexity")
@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
@register_metric("bits_per_byte")
@register_aggregation("bits_per_byte")
def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
@register_metric("bleu")
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
n-grams in the candidate translation to n-grams in the reference text, where
1-gram or unigram would be each token and a bigram comparison would be each
word pair. The comparison is made regardless of word order
Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
Paper: https://www.aclweb.org/anthology/P02-1040/
Higher is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_bleu(preds, refs).score
@register_metric("chrf")
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
Source: https://github.com/m-popovic/chrF
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
Higher is better # TODO I think
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_chrf(preds, refs).score
@register_metric("ter")
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
of the references
Source: http://www.cs.umd.edu/~snover/tercom/
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
Lower is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_ter(preds, refs).score
def is_non_str_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects (List[str], List[List[str]])
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
# Note [ref1_stream] is the first reference for each pred.
# So lists are size N and (M, N) for N preds and M possible refs for each pred
# This is a different order of dimensions than I would expect
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
if not is_non_str_iterable(refs):
refs = list(refs)
if not is_non_str_iterable(refs[0]):
refs = [[ref] for ref in refs]
refs = list(zip(*refs))
# Note the number of refs in each ref list must match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
if not is_non_str_iterable(preds):
preds = list(preds)
if is_non_str_iterable(preds[0]):
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
preds = [pred[0] for pred in preds]
return refs, preds
# stderr stuff
class _bootstrap_internal:
def __init__(self, f, n):
self.f = f
@@ -330,25 +87,6 @@ def bootstrap_stderr(f, xs, iters):
return sample_stddev(res)
def stderr_for_metric(metric, bootstrap_iters):
bootstrappable = [
median,
matthews_corrcoef,
f1_score,
perplexity,
bleu,
chrf,
ter,
]
if metric in bootstrappable:
return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
return stderr.get(metric, None)
def yesno(x):
if x:
return "yes"
...
from .aggregation import *
from .metric import *
from lm_eval.api.metrics import bootstrap_stderr, mean_stderr, acc_all_stderr
from lm_eval.api.register import (
metric_registry,
aggregation_registry,
higher_is_better_registry,
output_type_registry,
default_aggregation_registry,
)
METRIC_REGISTRY = metric_registry
OUTPUT_TYPE_REGISTRY = output_type_registry
AGGREGATION_REGISTRY = aggregation_registry
DEFAULT_AGGREGATION_REGISTRY = default_aggregation_registry
HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry
DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [
"perplexity",
"acc",
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": [
"acc",
],
"greedy_until": ["exact_match"],
}
def get_metric(name):
try:
return METRIC_REGISTRY[name]
except KeyError:
# TODO: change this print to logging?
print(
f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
)
try:
import evaluate
metric_object = evaluate.load(name)
return metric_object.compute
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(name),
"Please check https://huggingface.co/evaluate-metric",
)
def get_aggregation(name):
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
raise Warning(
"{} not a registered aggregation metric!".format(name),
)
def stderr_for_metric(metric, bootstrap_iters):
bootstrappable = [
"median",
"matthews_corrcoef",
"f1_score",
"perplexity",
"bleu",
"chrf",
"ter",
]
if metric in bootstrappable:
return lambda x: bootstrap_stderr(
METRIC_REGISTRY[metric], x, iters=bootstrap_iters
)
stderr = {"mean": mean_stderr, "acc_all": acc_all_stderr}
return stderr.get(metric, None)
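The lm_eval.api.register module imported above is not included in this diff. A minimal sketch of what it plausibly contains, assuming plain dict registries and decorator factories, with register_metric tagging the function so the decorators stacked above it in metric.py can key off the metric name; the real module may differ:
# Hypothetical sketch of lm_eval/api/register.py (not shown in this commit).
metric_registry = {}
aggregation_registry = {}
higher_is_better_registry = {}
output_type_registry = {}
default_aggregation_registry = {}

def register_metric(name):
    def decorate(fn):
        assert name not in metric_registry, f"metric '{name}' already registered!"
        metric_registry[name] = fn
        fn._metric_name = name  # read by the decorators applied after this one
        return fn
    return decorate

def register_aggregation(name):
    def decorate(fn):
        assert name not in aggregation_registry, f"aggregation '{name}' already registered!"
        aggregation_registry[name] = fn
        return fn
    return decorate

def register_higher_is_better(value):
    def decorate(fn):
        higher_is_better_registry[fn._metric_name] = value
        return fn
    return decorate

def register_output_type(output_type):
    def decorate(fn):
        # a metric may apply to several output types, so collect them
        output_type_registry.setdefault(fn._metric_name, []).append(output_type)
        return fn
    return decorate

def register_default_aggregation(agg_name):
    def decorate(fn):
        default_aggregation_registry[fn._metric_name] = agg_name
        return fn
    return decorate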
import math
from lm_eval.api.register import register_aggregation
def weighted_mean(items):
a, b = zip(*items)
return sum(a) / sum(b)
@register_aggregation("mean")
def mean(arr):
return sum(arr) / len(arr)
@register_aggregation("median")
def median(arr):
return arr[len(arr) // 2]
@register_aggregation("perplexity")
def perplexity(items):
return math.exp(-mean(items))
@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
@register_aggregation("bits_per_byte")
def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
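A quick worked example (illustrative numbers only) of the item shapes these aggregations expect: perplexity averages per-token log-likelihoods, while the weighted variants take (total log-likelihood, token or byte count) pairs.
print(perplexity([-2.0, -1.0]))                      # exp(1.5) ~= 4.4817
print(weighted_perplexity([(-10.0, 5), (-6.0, 3)]))  # exp(16 / 8) ~= 7.389
print(bits_per_byte([(-10.0, 5), (-6.0, 3)]))        # 2 / ln(2) ~= 2.885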
import math
from collections.abc import Iterable
import numpy as np
import sacrebleu
import sklearn.metrics
import random
from lm_eval.api.register import (
register_metric,
register_higher_is_better,
register_output_type,
register_default_aggregation,
)
@register_default_aggregation("mean")
@register_output_type("loglikelihood")
@register_output_type("multiple_choice")
@register_higher_is_better(True)
@register_metric("acc")
def acc_fn(items): # This is a passthrough function
return items
@register_default_aggregation("mean")
@register_output_type("multiple_choice")
@register_higher_is_better(True)
@register_metric("acc_norm")
def acc_norm_fn(items): # This is a passthrough function
return items
@register_default_aggregation("mean")
@register_output_type("multiple_choice")
@register_higher_is_better(True)
@register_metric("acc_mutual_info")
def acc_mutual_info_fn(items): # This is a passthrough function
return items
@register_default_aggregation("perplexity")
@register_output_type("loglikelihood")
@register_higher_is_better(False)
@register_metric("perplexity")
def perplexity_fn(items): # This is a passthrough function
return items
@register_default_aggregation("weighted_perplexity")
@register_output_type("loglikelihood_rolling")
@register_higher_is_better(False)
@register_metric("word_perplexity")
def word_perplexity_fn(items): # This is a passthrough function
return items
@register_default_aggregation("weighted_perplexity")
@register_output_type("loglikelihood_rolling")
@register_higher_is_better(False)
@register_metric("byte_perplexity")
def byte_perplexity_fn(items): # This is a passthrough function
return items
@register_default_aggregation("bits_per_byte")
@register_output_type("loglikelihood_rolling")
@register_higher_is_better(False)
@register_metric("bits_per_byte")
def bits_per_byte_fn(items): # This is a passthrough function
return items
@register_default_aggregation("mean")
@register_output_type("loglikelihood")
@register_higher_is_better(True)
@register_metric("acc_all")
def acc_all(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
preds = list(zip(*items))[0]
docs = list(zip(*items))[1]
for doc, pred in zip(docs, preds):
paragraph_id = doc["idx"]["paragraph"]
question_id = doc["idx"]["question"]
if (paragraph_id, question_id) not in question_scoring_dict:
question_scoring_dict[(paragraph_id, question_id)] = []
gold_label = doc["label"] == 1
question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
return acc
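An illustrative item format for acc_all (hypothetical MultiRC-style docs): each item is a (pred, doc) pair, and every answer belonging to the same (paragraph, question) index must be labeled correctly for that question to count.
doc_a = {"idx": {"paragraph": 0, "question": 0}, "label": 1}
doc_b = {"idx": {"paragraph": 0, "question": 0}, "label": 0}
print(acc_all([(True, doc_a), (False, doc_b)]))  # 1.0: both answers for the question are correct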
@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("matthews_corrcoef")
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
return sklearn.metrics.matthews_corrcoef(golds, preds)
@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("f1")
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = sklearn.metrics.f1_score(golds, preds)
return np.max(fscore)
def is_non_str_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects (List[str], List[List[str]])
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
# Note [ref1_stream] is the first reference for each pred.
# So lists are size N and (M, N) for N preds and M possible refs for each pred
# This is a different order of dimensions than I would expect
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
if not is_non_str_iterable(refs):
refs = list(refs)
if not is_non_str_iterable(refs[0]):
refs = [[ref] for ref in refs]
refs = list(zip(*refs))
# Note the number of refs in each ref list must match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
if not is_non_str_iterable(preds):
preds = list(preds)
if is_non_str_iterable(preds[0]):
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
preds = [pred[0] for pred in preds]
return refs, preds
@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("bleu")
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
n-grams in the candidate translation to n-grams in the reference text, where
1-gram or unigram would be each token and a bigram comparison would be each
word pair. The comparison is made regardless of word order
Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
Paper: https://www.aclweb.org/anthology/P02-1040/
Higher is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_bleu(preds, refs).score
@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("chrf")
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
Source: https://github.com/m-popovic/chrF
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
Higher is better # TODO I think
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_chrf(preds, refs).score
@register_default_aggregation("mean")
@register_higher_is_better(False)
@register_metric("ter")
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
of the references
Source: http://www.cs.umd.edu/~snover/tercom/
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
Lower is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_ter(preds, refs).score
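For reference, a minimal usage sketch of the corpus-level metrics above; the sentences are illustrative, and each item is a (reference, prediction) pair as the harness supplies them:
items = [
    ("the cat sat on the mat", "the cat sat on the mat"),
    ("it is raining today", "it rains today"),
]
print(bleu(items))   # corpus BLEU from sacrebleu, higher is better
print(chrf(items))   # chrF score, higher is better
print(ter(items))    # translation error rate, lower is better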