Commit 039832e5 authored by lintangsutawika's avatar lintangsutawika

removed passthrough fn

parent 3888193d
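For context on this change (an illustrative sketch, not part of the committed file; the decorator arguments mirror those appearing in the diff below): a "passthrough" metric registration pairs a function that simply returns its items with a separately named aggregation,

    @register_metric(
        metric="acc",
        higher_is_better=True,
        output_type="loglikelihood",
        aggregation="mean",
    )
    def acc_fn(items):  # passthrough: the aggregation named above does the work
        return items

whereas the style kept here registers a function that performs the aggregation itself (see aggregate_acc_fn and the other decorated functions below).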
import abc
import logging
import math
import random
from collections.abc import Iterable

import evaluate
import numpy as np
import sacrebleu
import sklearn.metrics

from lm_eval.api.registry import register_aggregation, register_metric

eval_logger = logging.getLogger("lm-eval")
class BaseMetric:
    def __init__(self) -> None:
        pass

    @abc.abstractmethod
    def update(self, *items):
        pass

    @abc.abstractmethod
    def compute(self, *items):
        pass
def mean(arr):
return sum(arr) / len(arr)
@@ -37,32 +22,43 @@ def median(arr):
return arr[len(arr) // 2]
def weighted_mean(items):
a, b = zip(*items)
return sum(a) / sum(b)
@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
)
class PerplexityMetric(BaseMetric):
    def update(self, ll, is_greedy):
        return ll

    def compute(self, items):
        return math.exp(-mean(items))


def perplexity(items):
    return math.exp(-mean(items))
@register_metric(
metric=["word_perplexity", "byte_perplexity"],
higher_is_better=False,
output_type="loglikelihood_rolling",
)
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type="loglikelihood",
)
class LoglikelihoodAccMetric(BaseMetric):
    def update(self, ll, is_greedy):
        return int(is_greedy)

    def compute(self, items):
        return math.exp(-mean(items))


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)
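# Illustrative note (not in the original file): items for the
# "loglikelihood_rolling" metrics are assumed to be (log-likelihood, weight)
# pairs, where the weight is a word or byte count, so weighted_mean gives the
# average log-likelihood per word/byte. For example,
# items = [(-10.0, 5), (-6.0, 3)] gives weighted_mean = -16.0 / 8 = -2.0,
# hence weighted_perplexity = exp(2.0) ≈ 7.39 and
# bits_per_byte = 2.0 / log(2) ≈ 2.89.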
@register_aggregation("f1")
@register_metric(
metric="f1",
higher_is_better=True,
output_type="multiple_choice",
)
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
@@ -72,16 +68,23 @@ def f1_score(items):
return np.max(fscore)
@register_aggregation("matthews_corrcoef")
@register_metric(
metric="mcc",
higher_is_better=True,
output_type="multiple_choice",
)
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
# print(preds)
return sklearn.metrics.matthews_corrcoef(golds, preds)
@register_aggregation("bleu")
@register_metric(
metric="bleu",
higher_is_better=True,
output_type="generate_until",
)
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
@@ -99,7 +102,11 @@ def bleu(items):
return sacrebleu.corpus_bleu(preds, refs).score
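# Illustrative note (not in the original file): items for the sacrebleu-based
# metrics (bleu, chrf, ter) are assumed to be (reference, prediction) pairs,
# mirroring the gold/pred ordering used in f1_score above, so zip(*items)
# yields corpus-level reference and prediction lists. For a single pair,
#   sacrebleu.corpus_bleu(["the cat sat on a mat"],
#                         [["the cat sat on the mat"]]).score
# returns the corpus BLEU score as a float in [0, 100].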
@register_aggregation("chrf")
@register_metric(
metric="chrf",
higher_is_better=True,
output_type="generate_until",
)
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
@@ -114,7 +121,11 @@ def chrf(items):
return sacrebleu.corpus_chrf(preds, refs).score
@register_aggregation("ter")
@register_metric(
metric="ter",
higher_is_better=True,
output_type="generate_until",
)
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
@@ -130,86 +141,34 @@ def ter(items):
return sacrebleu.corpus_ter(preds, refs).score
# @register_metric(
# metric="acc",
# higher_is_better=True,
# output_type=["loglikelihood", "multiple_choice"],
# aggregation="mean",
# )
# def acc_fn(items): # This is a passthrough function
# return items
# @register_metric(
# metric="acc_norm",
# higher_is_better=True,
# output_type=["loglikelihood", "multiple_choice"],
# aggregation="mean",
# )
# def acc_norm_fn(items): # This is a passthrough function
# return items
# @register_metric(
# metric="acc_mutual_info",
# higher_is_better=True,
# output_type="multiple_choice",
# aggregation="mean",
# )
# def acc_mutual_info_fn(items): # This is a passthrough function
# return items
exact_match = evaluate.load("exact_match")
# @register_metric(
# metric="exact_match",
# higher_is_better=True,
# output_type="generate_until",
# aggregation="mean",
# )
# def exact_match_fn(**kwargs):
# return exact_match.compute(**kwargs)
@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
class WordPerplexityMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _words

    def set_wise_compute(self, items):
        return math.exp(-weighted_mean(items))


@register_metric(
    metric=["acc", "acc_norm"],
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
)
def aggregate_acc_fn(items):
    return mean(items)
@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
class BytePerplexityMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _bytes

    def set_wise_compute(self, items):
        return math.exp(-weighted_mean(items))


@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
)
def acc_mutual_info_fn(items):
    return mean(items)
exact_match = evaluate.load("exact_match")
@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
class BitsPerByteMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _bytes

    def set_wise_compute(self, items):
        return -weighted_mean(items) / math.log(2)


@register_metric(
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
)
def exact_match_fn(**kwargs):
    return exact_match.compute(**kwargs)
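# Illustrative usage (not in the original file): the Hugging Face `evaluate`
# exact_match metric takes parallel lists of strings, e.g.
#   exact_match.compute(predictions=["Paris"], references=["Paris"])
# returns {"exact_match": 1.0}; exact_match_fn simply forwards its keyword
# arguments to that call.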
def pop_stddev(arr):
@@ -226,79 +185,28 @@ def mean_stderr(arr):
return sample_stddev(arr) / math.sqrt(len(arr))
# @register_metric(
# metric="mcc",
# higher_is_better=True,
# output_type="multiple_choice",
# aggregation="matthews_corrcoef",
# )
# def mcc_fn(items): # This is a passthrough function
# return items
# @register_metric(
# metric="f1",
# higher_is_better=True,
# output_type="multiple_choice",
# aggregation="f1",
# )
# def f1_fn(items): # This is a passthrough function
# return items
# @register_metric(
# metric="bleu",
# higher_is_better=True,
# output_type="generate_until",
# aggregation="bleu",
# )
# def bleu_fn(items): # This is a passthrough function
# return items
# @register_metric(
# metric="chrf",
# higher_is_better=True,
# output_type="generate_until",
# aggregation="chrf",
# )
# def chrf_fn(items): # This is a passthrough function
# return items
# @register_metric(
# metric="ter",
# higher_is_better=True,
# output_type="generate_until",
# aggregation="ter",
# )
# def ter_fn(items): # This is a passthrough function
# return items
# @register_metric(
# metric="acc_all",
# higher_is_better=True,
# output_type="loglikelihood",
# aggregation="mean",
# )
# def acc_all(items):
# # Only count as correct if all answers are labeled correctly for each question
# question_scoring_dict = {}
# preds = list(zip(*items))[0]
# docs = list(zip(*items))[1]
# for doc, pred in zip(docs, preds):
# paragraph_id = doc["idx"]["paragraph"]
# question_id = doc["idx"]["question"]
# if (paragraph_id, question_id) not in question_scoring_dict:
# question_scoring_dict[(paragraph_id, question_id)] = []
# gold_label = doc["label"] == 1
# question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
# acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
# return acc
@register_metric(
metric="acc_all",
higher_is_better=True,
output_type="loglikelihood",
)
def acc_all(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
preds = list(zip(*items))[0]
docs = list(zip(*items))[1]
for doc, pred in zip(docs, preds):
paragraph_id = doc["idx"]["paragraph"]
question_id = doc["idx"]["question"]
if (paragraph_id, question_id) not in question_scoring_dict:
question_scoring_dict[(paragraph_id, question_id)] = []
gold_label = doc["label"] == 1
question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
return acc
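# Illustrative note (not in the original file): each item here is assumed to be
# a (prediction, doc) pair where doc carries nested indices and a label, e.g.
#   doc = {"idx": {"paragraph": 0, "question": 1}, "label": 1}
# Predictions are grouped per (paragraph, question), and a question only counts
# as correct when every one of its answers is predicted correctly.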
def acc_all_stderr(items):
@@ -328,11 +236,6 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths)
def weighted_mean(items):
a, b = zip(*items)
return sum(a) / sum(b)
def is_non_str_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)