import math
from collections.abc import Iterable
from pprint import pprint

import numpy as np
import sacrebleu
import sklearn.metrics


def _unzip_pairs(items):
    """Split an iterable of tuples into its first and second columns.

    The input is unzipped exactly once, so one-shot iterators are handled
    correctly. Raises IndexError when `items` is empty or its elements have
    fewer than two fields.
    """
    columns = list(zip(*items))
    return columns[0], columns[1]


def mean(arr):
    """Return the arithmetic mean of `arr`.

    Raises ZeroDivisionError when `arr` is empty.
    """
    return sum(arr) / len(arr)


def median(arr):
    """Return the middle element of `arr` taken in sorted order.

    The input is sorted first so the result does not depend on the order in
    which values were collected. For an even number of elements the upper of
    the two middle values is returned (no averaging), i.e. the element at
    index ``len(arr) // 2`` of the sorted sequence.

    Raises IndexError when `arr` is empty.
    """
    return sorted(arr)[len(arr) // 2]


def matthews_corrcoef(items):
    """Return the Matthews correlation coefficient over (gold, pred) pairs.

    `items` is an iterable of tuples whose first field is the gold label and
    whose second field is the prediction; scoring is delegated to
    `sklearn.metrics.matthews_corrcoef`.
    """
    golds, preds = _unzip_pairs(items)
    return sklearn.metrics.matthews_corrcoef(golds, preds)


def f1_score(items):
    """Return the F1 score over (gold, pred) pairs.

    `items` is an iterable of tuples whose first field is the gold label and
    whose second field is the prediction. The value produced by
    `sklearn.metrics.f1_score` is reduced with `np.max` before being returned.
    """
    golds, preds = _unzip_pairs(items)
    fscore = sklearn.metrics.f1_score(golds, preds)
    return np.max(fscore)


def acc_all(items):
    """Return the fraction of questions whose answers are all labeled correctly.

    `items` is an iterable of (pred, doc) tuples. Each `doc` must provide
    ``doc["idx"]["question"]`` (used to group answers by question) and
    ``doc["label"]`` (an answer is gold-true when the label equals 1).
    A question only counts as correct if every one of its answers matches.
    """
    preds, docs = _unzip_pairs(items)

    # question id -> one boolean per answer: did the prediction match gold?
    question_scoring_dict = {}
    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        gold_label = doc["label"] == 1
        question_scoring_dict.setdefault(question_id, []).append(gold_label == pred)

    return np.mean([int(all(x)) for x in question_scoring_dict.values()])


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth.

    `metric_fn` is called as ``metric_fn(prediction, ground_truth)`` for every
    entry of `ground_truths`. Raises ValueError when `ground_truths` is empty.
    """
    return max(metric_fn(prediction, ground_truth) for ground_truth in ground_truths)


def perplexity(items):
    """Return ``exp(-mean(items))``.

    NOTE(review): presumably `items` are per-example log-likelihoods; confirm
    against the callers.
    """
    return math.exp(-mean(items))


def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence to a reference sentence. It counts matching
    n-grams in the candidate translation to n-grams in the reference text, where
    1-gram or unigram would be each token and a bigram comparison would be each
    word pair. The comparison is made regardless of word order

    `items` is an iterable of (refs, pred) tuples.

    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs, preds = _unzip_pairs(items)
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score


def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.

    `items` is an iterable of (refs, pred) tuples.

    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better
    """
    refs, preds = _unzip_pairs(items)
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references

    `items` is an iterable of (refs, pred) tuples.

    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs, preds = _unzip_pairs(items)
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score


def is_non_str_iterable(obj):
    """Return True when `obj` is an Iterable other than a `str`."""
    return isinstance(obj, Iterable) and not isinstance(obj, str)


def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular

    Returns a ``(refs, preds)`` tuple in the layout sacrebleu expects.
    Raises AssertionError when a prediction is a sequence holding anything
    other than exactly one element.
    """
    # Sacrebleu expects (List[str], List[List[str]])
    #   e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])

    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred
    # This is a different order of dimensions than I would expect

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
    # Must become List[List[str]] with the inner list corresponding to preds
    # NOTE(review): this branch only fires for a bare str (split into
    # characters) or a non-iterable (TypeError). Intent unclear - confirm.
    if not is_non_str_iterable(refs):
        refs = list(refs)
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    # Transpose from one-list-per-pred to one-stream-per-reference-index.
    refs = list(zip(*refs))
    # Note the number of refs in each ref list must match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
    # NOTE(review): same caveat as for refs above - confirm intent.
    if not is_non_str_iterable(preds):
        preds = list(preds)
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds