import logging
import math
import random
from collections.abc import Iterable

import evaluate
import numpy as np
import sacrebleu
import sklearn.metrics

from lm_eval.api.registry import register_aggregation, register_metric


eval_logger = logging.getLogger("lm-eval")


# Register Aggregations First
@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)


@register_aggregation("median")
def median(arr):
    # middle element of the sorted values
    return sorted(arr)[len(arr) // 2]


# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@register_aggregation("perplexity")
def perplexity(items):
    return math.exp(-mean(items))


@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))


@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)


@register_aggregation("f1")
def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)
    return fscore


@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return sklearn.metrics.matthews_corrcoef(golds, preds)


@register_aggregation("bleu")
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence against a reference sentence. It counts
    matching n-grams in the candidate translation and the reference text, where a
    1-gram or unigram would be each token and a bigram comparison would be each
    word pair. The comparison is made regardless of word order.
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score
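# Illustrative sketches (hypothetical helpers, not used by the harness) of the
# item formats the aggregations above appear to consume, inferred from the
# implementations: `weighted_perplexity` receives (sum_of_loglikelihoods,
# num_units) pairs per document, and `bleu` receives (reference, prediction)
# string pairs. All values below are made up.
def _example_weighted_perplexity_items():  # pragma: no cover - illustrative only
    # Two documents: total loglikelihood -12.0 over 5 words and -4.0 over 2 words,
    # giving a corpus-level weighted perplexity of exp(16.0 / 7).
    return weighted_perplexity([(-12.0, 5), (-4.0, 2)])


def _example_bleu_items():  # pragma: no cover - illustrative only
    # One (reference, prediction) pair per document; `bleu` reshapes these via
    # _sacreformat and returns the sacrebleu corpus score (0-100, higher is better).
    return bleu(
        [
            ("the cat sat on the mat", "the cat is on the mat"),
            ("it is raining today", "it rains today"),
        ]
    )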
@register_aggregation("chrf")
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


@register_aggregation("ter")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references.
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score


@register_aggregation("ece")
def ece(items: list) -> float:
    probs: list[float] = []
    scores: list[float] = []
    for i in range(len(items)):
        # Get only the largest probability from each example
        largest_idx = np.argmax(items[i]["probs"])
        probs.append(items[i]["probs"][largest_idx])
        scores.append(items[i]["scores"][largest_idx])

    sorted_indices = np.argsort(probs)
    sorted_probs = np.asarray(probs)[sorted_indices]
    sorted_scores = np.asarray(scores)[sorted_indices]

    def bin_to_subsets(array: np.ndarray, num_subsets: int = 10) -> list[np.ndarray]:
        subset_size: int = len(array) // num_subsets
        remainder: int = len(array) % num_subsets
        subsets: list[np.ndarray] = []
        start: int = 0
        for _ in range(num_subsets):
            subset_end: int = start + subset_size + (1 if remainder > 0 else 0)
            subsets.append(array[start:subset_end])
            start = subset_end
            remainder -= 1
        return subsets

    probs = np.asarray([np.mean(x) for x in bin_to_subsets(sorted_probs, 10)])
    freqs = np.asarray([np.mean(x) for x in bin_to_subsets(sorted_scores, 10)])
    return np.sum(np.abs(freqs - probs)) / len(freqs)
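# A minimal, illustrative sketch (hypothetical helper, not part of the API) of
# the per-example dicts the "ece" aggregation above expects: normalized choice
# probabilities under "probs" and per-choice 0/1 correctness under "scores".
# The aggregation uses 10 equal-count bins, so a sample of at least 10 examples
# is assumed; the data below is synthetic.
def _example_ece_items():  # pragma: no cover - illustrative only
    rng = np.random.default_rng(0)
    items = []
    for _ in range(50):
        probs = rng.dirichlet(np.ones(4))  # model's normalized choice probabilities
        gold = int(rng.integers(4))  # index of the correct choice
        scores = [1.0 if i == gold else 0.0 for i in range(4)]
        items.append({"probs": probs.tolist(), "scores": scores})
    # Average |observed accuracy - mean confidence| over the 10 bins.
    return ece(items)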
@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_norm",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_norm_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def acc_mutual_info_fn(items):  # This is a passthrough function
    return items


exact_match = evaluate.load("exact_match")


@register_metric(
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="mean",
)
def exact_match_fn(**kwargs):
    return exact_match.compute(**kwargs)


@register_metric(
    metric="ece",
    higher_is_better=False,
    output_type="multiple_choice",
    aggregation="ece",
)
def ece_fn(items):  # This is a passthrough function
    """Expected Calibration Error (ECE).

    This consists of the average absolute difference between the fraction of
    model predictions which are correct and the mean of the model's normalized
    probability for those predictions (after binning), for multiple choice
    questions.

    Paper: https://arxiv.org/abs/2207.05221
    """
    return items


@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
    aggregation="perplexity",
)
def perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def word_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="bits_per_byte",
)
def bits_per_byte_fn(items):  # This is a passthrough function
    return items


def pop_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


def sample_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))


def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))


@register_metric(
    metric="mcc",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="matthews_corrcoef",
)
def mcc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="f1",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="f1",
)
def f1_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bleu",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="bleu",
)
def bleu_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="chrf",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="chrf",
)
def chrf_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="ter",
    higher_is_better=False,  # TER is an error rate: lower is better, per the `ter` aggregation docstring
    output_type="generate_until",
    aggregation="ter",
)
def ter_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_all",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)

    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc


def acc_all_stderr(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        if question_id not in question_scoring_dict:
            question_scoring_dict[question_id] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[question_id].append(gold_label == pred)

    acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
    return acc
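# Illustrative sketch of the (pred, doc) items `acc_all` above consumes for
# MultiRC-style grouped scoring; the docs below are hypothetical minimal stubs
# with only the fields the function reads ("idx" and "label").
def _example_acc_all_items():  # pragma: no cover - illustrative only
    items = [
        # question (0, 0): both answer candidates classified correctly
        (True, {"idx": {"paragraph": 0, "question": 0}, "label": 1}),
        (False, {"idx": {"paragraph": 0, "question": 0}, "label": 0}),
        # question (0, 1): one candidate misclassified, so the question scores 0
        (True, {"idx": {"paragraph": 0, "question": 1}, "label": 0}),
    ]
    return acc_all(items)  # 0.5: one of the two questions is fully correct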
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)


def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)


def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular."""
    # Sacrebleu expects (List[str], List[List[str]]), e.g.
    # sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred.
    # This is a different order of dimensions than I would expect.

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding
    # to preds. Must become List[List[str]] with the inner list corresponding to preds.
    if not is_non_str_iterable(refs):
        # wrap a single reference rather than splitting a string into characters
        refs = [refs]
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))
    # Note the number of refs in each ref list must match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str].
    if not is_non_str_iterable(preds):
        preds = [preds]
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds


# stderr stuff
class _bootstrap_internal:
    def __init__(self, f, n) -> None:
        self.f = f
        self.n = n

    def __call__(self, v):
        i, xs = v
        rnd = random.Random()
        rnd.seed(i)
        res = []
        for _ in range(self.n):
            res.append(self.f(rnd.choices(xs, k=len(xs))))
        return res


def bootstrap_stderr(f, xs, iters):
    import multiprocessing as mp

    pool = mp.Pool(mp.cpu_count())
    # This gives a biased estimate of the stderr (i.e. with the mean, it gives
    # something equivalent to the stderr calculated without Bessel's correction
    # in the stddev). Unfortunately, I haven't been able to figure out what the
    # right correction is to make the bootstrap unbiased - I considered
    # multiplying by sqrt(n/(n-1)), but that would be ad-hoc and I can't prove
    # it would actually be an unbiased estimator. Thankfully, this shouldn't
    # matter much because our samples are usually pretty big anyway.
    res = []
    chunk_size = min(1000, iters)

    from tqdm import tqdm

    print("bootstrapping for stddev:", f.__name__)
    for bootstrap in tqdm(
        pool.imap(
            _bootstrap_internal(f, chunk_size),
            [(i, xs) for i in range(iters // chunk_size)],
        ),
        total=iters // chunk_size,
    ):
        # sample w/ replacement
        res.extend(bootstrap)

    pool.close()
    return sample_stddev(res)


def stderr_for_metric(metric, bootstrap_iters):
    bootstrappable = [
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

    return stderr.get(metric, None)
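# Illustrative sketch of how `stderr_for_metric` above is meant to be driven;
# the per-document values are made up. `mean` resolves to the analytic
# `mean_stderr`, while bootstrappable aggregations such as `median` get a
# resampling-based estimate via `bootstrap_stderr`.
def _example_stderr_usage():  # pragma: no cover - illustrative only
    values = [0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
    analytic = stderr_for_metric(mean, bootstrap_iters=100)(values)
    bootstrapped = stderr_for_metric(median, bootstrap_iters=100)(values)
    return analytic, bootstrapped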