import math

import numpy as np
import sacrebleu
import sklearn.metrics


def mean(arr):
    return sum(arr) / len(arr)


def median(arr):
    # Sort first so the middle element is the (upper) median even for unsorted input.
    return sorted(arr)[len(arr) // 2]


def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return sklearn.metrics.matthews_corrcoef(golds, preds)


def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)
    return np.max(fscore)


def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]
    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        if question_id not in question_scoring_dict:
            question_scoring_dict[question_id] = []
        gold_label = doc["label"] == 1
        question_scoring_dict[question_id].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def perplexity(items):
    # items are assumed to be log-likelihoods; perplexity is exp(-mean(log-likelihoods)).
    return math.exp(-mean(items))


def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence against a reference sentence. It counts
    matching n-grams between the candidate translation and the reference text,
    where a 1-gram or unigram would be each token and a bigram comparison would
    be each word pair. The comparison is made regardless of word order.

    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    # Assumes each item is a (prediction, reference) pair with a single reference
    # per prediction; sacrebleu expects references as a list of reference streams.
    preds = list(zip(*items))[0]
    refs = list(zip(*items))[1]
    return sacrebleu.corpus_bleu(preds, [refs]).score


def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.

    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better
    """
    # Same assumed (prediction, reference) item layout as bleu above.
    preds = list(zip(*items))[0]
    refs = list(zip(*items))[1]
    return sacrebleu.corpus_chrf(preds, [refs]).score


def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references.

    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    # Same assumed (prediction, reference) item layout as bleu above.
    preds = list(zip(*items))[0]
    refs = list(zip(*items))[1]
    return sacrebleu.corpus_ter(preds, [refs]).score
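

# Usage sketch (illustrative only): a minimal smoke test showing the assumed
# `items` layouts. The classification metrics above take (gold, prediction)
# pairs, while the translation metrics take (prediction, reference) pairs.
# The toy values below are made up for demonstration and are not from any task.
if __name__ == "__main__":
    clf_items = [(1, 1), (0, 1), (1, 1), (0, 0)]  # (gold, prediction) pairs
    print("MCC:", matthews_corrcoef(clf_items))
    print("F1 :", f1_score(clf_items))

    mt_items = [
        ("the cat sat on the mat", "the cat sat on the mat"),
        ("a dog barked", "the dog barked"),
    ]  # (prediction, reference) pairs
    print("BLEU:", bleu(mt_items))
    print("chrF:", chrf(mt_items))
    print("TER :", ter(mt_items))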