Commit 039832e5 authored by lintangsutawika

removed passthrough fn

parent 3888193d
+import logging
 import math
+import random
 from collections.abc import Iterable
-import abc
+import evaluate
 import numpy as np
 import sacrebleu
 import sklearn.metrics
-import random
-import evaluate
-from lm_eval.api.registry import register_metric, register_aggregation
-import logging
+from lm_eval.api.registry import register_metric

 eval_logger = logging.getLogger("lm-eval")

-class BaseMetric:
-    def __init__(
-        self,
-    ) -> None:
-    @abc.abstractmethod
-    def update(self, *items):
-        pass
-
-    @abc.abstractmethod
-    def compute(self, *items):
-        pass

 def mean(arr):
     return sum(arr) / len(arr)

@@ -37,32 +22,43 @@ def median(arr):
     return arr[len(arr) // 2]

+def weighted_mean(items):
+    a, b = zip(*items)
+    return sum(a) / sum(b)
+
+
 @register_metric(
     metric="perplexity",
     higher_is_better=False,
     output_type="loglikelihood",
 )
-class PerplexityMetric(BaseMetric):
-    def update(self, ll, is_greedy):
-        return ll
-
-    def compute(self, items):
-        return math.exp(-mean(items))
+def perplexity(items):
+    return math.exp(-mean(items))
+
+
+@register_metric(
+    metric=["word_perplexity", "byte_perplexity"],
+    higher_is_better=False,
+    output_type="loglikelihood_rolling",
+)
+def weighted_perplexity(items):
+    return math.exp(-weighted_mean(items))


 @register_metric(
-    metric="acc",
-    higher_is_better=True,
-    output_type="loglikelihood",
+    metric="bits_per_byte",
+    higher_is_better=False,
+    output_type="loglikelihood_rolling",
 )
-class LoglikelihoodAccMetric(BaseMetric):
-    def update(self, ll, is_greedy):
-        return int(is_greedy)
-
-    def compute(self, items):
-        return math.exp(-mean(items))
+def bits_per_byte(items):
+    return -weighted_mean(items) / math.log(2)


-@register_aggregation("f1")
+@register_metric(
+    metric="f1",
+    higher_is_better=True,
+    output_type="multiple_choice",
+)
 def f1_score(items):
     unzipped_list = list(zip(*items))
     golds = unzipped_list[0]

@@ -72,16 +68,23 @@ def f1_score(items):
     return np.max(fscore)

-@register_aggregation("matthews_corrcoef")
+@register_metric(
+    metric="mcc",
+    higher_is_better=True,
+    output_type="multiple_choice",
+)
 def matthews_corrcoef(items):
     unzipped_list = list(zip(*items))
     golds = unzipped_list[0]
     preds = unzipped_list[1]
-    # print(preds)
     return sklearn.metrics.matthews_corrcoef(golds, preds)


-@register_aggregation("bleu")
+@register_metric(
+    metric="bleu",
+    higher_is_better=True,
+    output_type="generate_until",
+)
 def bleu(items):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching

@@ -99,7 +102,11 @@ def bleu(items):
     return sacrebleu.corpus_bleu(preds, refs).score

-@register_aggregation("chrf")
+@register_metric(
+    metric="chrf",
+    higher_is_better=True,
+    output_type="generate_until",
+)
 def chrf(items):
     """chrF++ is a tool for automatic evaluation of machine translation output
     based on character n-gram precision and recall enhanced with word n-grams.

@@ -114,7 +121,11 @@ def chrf(items):
     return sacrebleu.corpus_chrf(preds, refs).score

-@register_aggregation("ter")
+@register_metric(
+    metric="ter",
+    higher_is_better=True,
+    output_type="generate_until",
+)
 def ter(items):
     """Translation Error Rate is an error metric for machine translation that
     measures the number of edits required to change a system output into one

@@ -130,86 +141,34 @@ def ter(items):
     return sacrebleu.corpus_ter(preds, refs).score

-# @register_metric(
-#     metric="acc",
-#     higher_is_better=True,
-#     output_type=["loglikelihood", "multiple_choice"],
-#     aggregation="mean",
-# )
-# def acc_fn(items): # This is a passthrough function
-#     return items
-
-
-# @register_metric(
-#     metric="acc_norm",
-#     higher_is_better=True,
-#     output_type=["loglikelihood", "multiple_choice"],
-#     aggregation="mean",
-# )
-# def acc_norm_fn(items): # This is a passthrough function
-#     return items
-
-
-# @register_metric(
-#     metric="acc_mutual_info",
-#     higher_is_better=True,
-#     output_type="multiple_choice",
-#     aggregation="mean",
-# )
-# def acc_mutual_info_fn(items): # This is a passthrough function
-#     return items
-
-
-exact_match = evaluate.load("exact_match")
-
-
-# @register_metric(
-#     metric="exact_match",
-#     higher_is_better=True,
-#     output_type="generate_until",
-#     aggregation="mean",
-# )
-# def exact_match_fn(**kwargs):
-#     return exact_match.compute(**kwargs)
-
-
 @register_metric(
-    metric="word_perplexity",
-    higher_is_better=False,
-    output_type="loglikelihood_rolling",
+    metric=["acc", "acc_norm"],
+    higher_is_better=True,
+    output_type=["loglikelihood", "multiple_choice"],
 )
-class BytePerplexityMetric(BaseMetric):
-    def sample_wise_compute(self, loglikelihood, _words, _bytes):
-        return loglikelihood, _words
-
-    def set_wise_compute(self, items):
-        return math.exp(-weighted_mean(items))
+def aggregate_acc_fn(items):
+    return mean(items)


 @register_metric(
-    metric="byte_perplexity",
-    higher_is_better=False,
-    output_type="loglikelihood_rolling",
+    metric="acc_mutual_info",
+    higher_is_better=True,
+    output_type="multiple_choice",
 )
-class BytePerplexityMetric(BaseMetric):
-    def sample_wise_compute(self, loglikelihood, _words, _bytes):
-        return loglikelihood, _bytes
-
-    def set_wise_compute(self, items):
-        return math.exp(-weighted_mean(items))
+def acc_mutual_info_fn(items):
+    return mean(items)
+
+
+exact_match = evaluate.load("exact_match")


 @register_metric(
-    metric="bits_per_byte",
-    higher_is_better=False,
-    output_type="loglikelihood_rolling",
+    metric="exact_match",
+    higher_is_better=True,
+    output_type="generate_until",
 )
-class BitsPerByteMetric(BaseMetric):
-    def sample_wise_compute(self, loglikelihood, _words, _bytes):
-        return loglikelihood, _bytes
-
-    def set_wise_compute(self, items):
-        return -weighted_mean(items) / math.log(2)
+def exact_match_fn(**kwargs):
+    return exact_match.compute(**kwargs)


 def pop_stddev(arr):

@@ -226,79 +185,28 @@ def mean_stderr(arr):
     return sample_stddev(arr) / math.sqrt(len(arr))

-# @register_metric(
-#     metric="mcc",
-#     higher_is_better=True,
-#     output_type="multiple_choice",
-#     aggregation="matthews_corrcoef",
-# )
-# def mcc_fn(items): # This is a passthrough function
-#     return items
-
-
-# @register_metric(
-#     metric="f1",
-#     higher_is_better=True,
-#     output_type="multiple_choice",
-#     aggregation="f1",
-# )
-# def f1_fn(items): # This is a passthrough function
-#     return items
-
-
-# @register_metric(
-#     metric="bleu",
-#     higher_is_better=True,
-#     output_type="generate_until",
-#     aggregation="bleu",
-# )
-# def bleu_fn(items): # This is a passthrough function
-#     return items
-
-
-# @register_metric(
-#     metric="chrf",
-#     higher_is_better=True,
-#     output_type="generate_until",
-#     aggregation="chrf",
-# )
-# def chrf_fn(items): # This is a passthrough function
-#     return items
-
-
-# @register_metric(
-#     metric="ter",
-#     higher_is_better=True,
-#     output_type="generate_until",
-#     aggregation="ter",
-# )
-# def ter_fn(items): # This is a passthrough function
-#     return items
-
-
-# @register_metric(
-#     metric="acc_all",
-#     higher_is_better=True,
-#     output_type="loglikelihood",
-#     aggregation="mean",
-# )
-# def acc_all(items):
-#     # Only count as correct if all answers are labeled correctly for each question
-#     question_scoring_dict = {}
-#     preds = list(zip(*items))[0]
-#     docs = list(zip(*items))[1]
-
-#     for doc, pred in zip(docs, preds):
-#         paragraph_id = doc["idx"]["paragraph"]
-#         question_id = doc["idx"]["question"]
-#         if (paragraph_id, question_id) not in question_scoring_dict:
-#             question_scoring_dict[(paragraph_id, question_id)] = []
-
-#         gold_label = doc["label"] == 1
-#         question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
-#     acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
-#     return acc
+@register_metric(
+    metric="acc_all",
+    higher_is_better=True,
+    output_type="loglikelihood",
+)
+def acc_all(items):
+    # Only count as correct if all answers are labeled correctly for each question
+    question_scoring_dict = {}
+    preds = list(zip(*items))[0]
+    docs = list(zip(*items))[1]
+
+    for doc, pred in zip(docs, preds):
+        paragraph_id = doc["idx"]["paragraph"]
+        question_id = doc["idx"]["question"]
+        if (paragraph_id, question_id) not in question_scoring_dict:
+            question_scoring_dict[(paragraph_id, question_id)] = []
+
+        gold_label = doc["label"] == 1
+        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
+    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
+    return acc


 def acc_all_stderr(items):

@@ -328,11 +236,6 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
     return max(scores_for_ground_truths)

-def weighted_mean(items):
-    a, b = zip(*items)
-    return sum(a) / sum(b)
-
-
 def is_non_str_iterable(obj):
     return isinstance(obj, Iterable) and not isinstance(obj, str)
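Note on the new function-based metrics above: each registered function takes the per-sample values collected for its metric and reduces them in a single step. A minimal sketch of that arithmetic, reusing the mean and weighted_mean helpers from the diff but with hypothetical sample values (nothing below is produced by the harness itself):

import math

def mean(arr):
    return sum(arr) / len(arr)

def weighted_mean(items):
    # items are (loglikelihood, weight) pairs, e.g. (total loglikelihood, byte count)
    a, b = zip(*items)
    return sum(a) / sum(b)

# Hypothetical per-sample loglikelihoods for the "perplexity" metric
lls = [-1.2, -0.7, -2.3]
print(math.exp(-mean(lls)))                   # perplexity(items): exp(-mean)

# Hypothetical (loglikelihood, n_bytes) pairs for the rolling metrics
rolling = [(-35.0, 28), (-50.0, 40)]
print(math.exp(-weighted_mean(rolling)))      # weighted_perplexity (word/byte perplexity)
print(-weighted_mean(rolling) / math.log(2))  # bits_per_byte

Only the numeric values are invented for illustration; the reduction formulas match the functions introduced on the right-hand side of the diff.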