import logging
import math
import random
from collections.abc import Iterable

import evaluate
import numpy as np
import sacrebleu
import sklearn.metrics

from lm_eval.api.registry import register_metric, register_aggregation

eval_logger = logging.getLogger("lm-eval")


class BaseMetric:
    """Base class for metrics: per-sample scores plus an optional aggregation."""

    def __init__(
        self,
        aggregation=None,
    ) -> None:
        self.aggregation = aggregation

    def __call__(self, *items):
        sample_wise_score = self.sample_wise_compute(*items)
        if self.aggregation is not None:
            return self.aggregation(sample_wise_score)
        else:
            return self.set_wise_compute(sample_wise_score)

    def sample_wise_compute(self, *items):
        return items

    def set_wise_compute(self, *items):
        return items


# Register Aggregations First
@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)


@register_aggregation("median")
def median(arr):
    # sort first so the middle element is an actual median
    return sorted(arr)[len(arr) // 2]


@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
)
class PerplexityMetric(BaseMetric):
    def sample_wise_compute(self, ll, is_greedy):
        return ll

    def set_wise_compute(self, items):
        return math.exp(-mean(items))


@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
class LoglikelihoodAccMetric(BaseMetric):
    def __call__(self, ll, is_greedy):
        return int(is_greedy)
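

# Illustrative sketch (not part of the original module; the helper name
# `_demo_perplexity_metric` is hypothetical, and it assumes the registry
# decorator returns the decorated class unchanged): how the two-stage
# BaseMetric API is meant to be used. Per-sample calls collect raw
# log-likelihoods, and the set-wise step turns them into a corpus-level
# perplexity.
def _demo_perplexity_metric():
    metric = PerplexityMetric()
    # sample_wise_compute keeps only the log-likelihood of each sample
    per_sample = [metric.sample_wise_compute(ll, False) for ll in (-2.3, -1.7, -2.0)]
    # set_wise_compute exponentiates the negative mean: exp(-mean(lls)) ~= 7.39
    return metric.set_wise_compute(per_sample)
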

@register_aggregation("f1")
def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)
    return np.max(fscore)


@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return sklearn.metrics.matthews_corrcoef(golds, preds)


@register_aggregation("bleu")
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence against a reference sentence. It counts
    n-grams in the candidate translation that match n-grams in the reference
    text, where a 1-gram or unigram is a single token and a bigram comparison
    is each word pair. The comparison is made regardless of word order.
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score


@register_aggregation("chrf")
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


@register_aggregation("ter")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references.
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score


# @register_metric(
#     metric="acc",
#     higher_is_better=True,
#     output_type=["loglikelihood", "multiple_choice"],
#     aggregation="mean",
# )
# def acc_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="acc_norm",
#     higher_is_better=True,
#     output_type=["loglikelihood", "multiple_choice"],
#     aggregation="mean",
# )
# def acc_norm_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="acc_mutual_info",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="mean",
# )
# def acc_mutual_info_fn(items):  # This is a passthrough function
#     return items


exact_match = evaluate.load("exact_match")


# @register_metric(
#     metric="exact_match",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="mean",
# )
# def exact_match_fn(**kwargs):
#     return exact_match.compute(**kwargs)


@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
class WordPerplexityMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _words

    def set_wise_compute(self, items):
        return math.exp(-weighted_mean(items))


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
class BytePerplexityMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _bytes

    def set_wise_compute(self, items):
        return math.exp(-weighted_mean(items))


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
class BitsPerByteMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _bytes

    def set_wise_compute(self, items):
        return -weighted_mean(items) / math.log(2)


def pop_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


def sample_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))


def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))
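

# Illustrative sketch (not part of the original module; the helper name
# `_demo_byte_perplexity` and the toy numbers are hypothetical): what the
# rolling-loglikelihood metrics above compute from (loglikelihood, byte_count)
# pairs. Byte perplexity exponentiates the negative per-byte log-likelihood,
# and bits-per-byte is the same quantity converted from nats to bits.
def _demo_byte_perplexity():
    items = [(-10.0, 8), (-6.0, 4)]  # (sum of log-likelihoods in nats, #bytes)
    avg_nll_per_byte = -sum(ll for ll, _ in items) / sum(nb for _, nb in items)
    byte_ppl = math.exp(avg_nll_per_byte)           # == math.exp(-weighted_mean(items))
    bits_per_byte = avg_nll_per_byte / math.log(2)  # == -weighted_mean(items) / math.log(2)
    return byte_ppl, bits_per_byte
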

# @register_metric(
#     metric="mcc",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="matthews_corrcoef",
# )
# def mcc_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="f1",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="f1",
# )
# def f1_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="bleu",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="bleu",
# )
# def bleu_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="chrf",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="chrf",
# )
# def chrf_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="ter",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="ter",
# )
# def ter_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="acc_all",
#     higher_is_better=True,
#     output_type="loglikelihood",
#     aggregation="mean",
# )
# def acc_all(items):
#     # Only count as correct if all answers are labeled correctly for each question
#     question_scoring_dict = {}
#     preds = list(zip(*items))[0]
#     docs = list(zip(*items))[1]
#     for doc, pred in zip(docs, preds):
#         paragraph_id = doc["idx"]["paragraph"]
#         question_id = doc["idx"]["question"]
#         if (paragraph_id, question_id) not in question_scoring_dict:
#             question_scoring_dict[(paragraph_id, question_id)] = []
#         gold_label = doc["label"] == 1
#         question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
#     acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
#     return acc


def acc_all_stderr(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]
    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        if question_id not in question_scoring_dict:
            question_scoring_dict[question_id] = []
        gold_label = doc["label"] == 1
        question_scoring_dict[question_id].append(gold_label == pred)
    acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
    return acc


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)


def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)
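

# Illustrative sketch (not part of the original module; the helper name
# `_demo_metric_max` and the exact-match comparison are hypothetical): scoring a
# prediction against several acceptable gold answers and keeping the best match.
def _demo_metric_max():
    def exact(pred, gold):
        return float(pred.strip().lower() == gold.strip().lower())

    # the first gold answer matches after lowercasing, so this returns 1.0
    return metric_max_over_ground_truths(exact, "Paris", ["paris", "the city of Paris"])
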

def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular."""
    # Sacrebleu expects (List[str], List[List[str]]), e.g.
    # sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred.
    # This is a different order of dimensions than I would expect.

    # We expect refs to be List[str] or List[List[str]], the outer list
    # corresponding to preds. Must become List[List[str]] with the inner list
    # corresponding to preds.
    if not is_non_str_iterable(refs):
        refs = list(refs)
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))
    # Note the number of refs in each ref list must match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str].
    if not is_non_str_iterable(preds):
        preds = list(preds)
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds


# stderr stuff


class _bootstrap_internal:
    def __init__(self, f, n) -> None:
        self.f = f
        self.n = n

    def __call__(self, v):
        i, xs = v
        rnd = random.Random()
        rnd.seed(i)
        res = []
        for _ in range(self.n):
            res.append(self.f(rnd.choices(xs, k=len(xs))))
        return res


def bootstrap_stderr(f, xs, iters):
    import multiprocessing as mp

    pool = mp.Pool(mp.cpu_count())
    # This gives a biased estimate of the stderr (i.e. with the mean, it gives
    # something equivalent to the stderr calculated without Bessel's correction
    # in the stddev. Unfortunately, I haven't been able to figure out what the
    # right correction is to make the bootstrap unbiased - I considered
    # multiplying by sqrt(n/(n-1)) but that would be ad-hoc and I can't prove
    # that it would actually be an unbiased estimator). Thankfully, this
    # shouldn't matter because our samples are usually pretty big.
    res = []
    chunk_size = min(1000, iters)

    from tqdm import tqdm

    print("bootstrapping for stddev:", f.__name__)
    for bootstrap in tqdm(
        pool.imap(
            _bootstrap_internal(f, chunk_size),
            [(i, xs) for i in range(iters // chunk_size)],
        ),
        total=iters // chunk_size,
    ):
        # sample w/ replacement
        res.extend(bootstrap)

    pool.close()
    return sample_stddev(res)


def stderr_for_metric(metric, bootstrap_iters):
    # perplexity and acc_all are not listed here: neither exists as a plain
    # aggregation function in this module (perplexity is class-based above and
    # acc_all is currently commented out), so referencing them would raise a
    # NameError.
    bootstrappable = [
        median,
        matthews_corrcoef,
        f1_score,
        bleu,
        chrf,
        ter,
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

    stderr = {mean: mean_stderr}

    return stderr.get(metric, None)
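

# Illustrative sketch (not part of the original module; the helper name
# `_demo_sacreformat` and the toy sentences are hypothetical): the shapes
# _sacreformat produces for sacrebleu. Two predictions with one reference each
# become preds as a flat List[str] and refs as a single reference "stream".
def _demo_sacreformat():
    refs = [["the cat sat"], ["hello world"]]  # one reference per prediction
    preds = ["the cat sat", "hello there"]
    sacre_refs, sacre_preds = _sacreformat(refs, preds)
    # sacre_preds == ["the cat sat", "hello there"]
    # sacre_refs  == [("the cat sat", "hello world")]  -> one stream of references
    return sacrebleu.corpus_bleu(sacre_preds, sacre_refs).score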