import math
from collections.abc import Iterable
from pprint import pprint

import numpy as np
import sacrebleu
import sklearn.metrics


def mean(arr):
    return sum(arr) / len(arr)


def median(arr):
    # Middle element of the sorted values (the upper median for even-length input)
    return sorted(arr)[len(arr) // 2]


def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return sklearn.metrics.matthews_corrcoef(golds, preds)


def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)

    return np.max(fscore)
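
# Illustrative usage (not part of the original module; the (gold, pred) pair format is
# taken from the unpacking above), e.g.:
#
#     f1_score([(1, 1), (0, 1), (1, 0), (0, 0)])           # binary F1 over golds vs. preds
#     matthews_corrcoef([(1, 1), (0, 1), (1, 0), (0, 0)])  # MCC over the same pairs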


def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        if question_id not in question_scoring_dict:
            question_scoring_dict[question_id] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[question_id].append(gold_label == pred)

    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc
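
# Illustrative sketch (hypothetical docs; the format is inferred from the code above):
# both answer candidates belong to question 0, and the question only counts as correct
# when every prediction matches its gold label.
#
#     docs = [{"idx": {"question": 0}, "label": 1},
#             {"idx": {"question": 0}, "label": 0}]
#     preds = [True, False]
#     acc_all(list(zip(preds, docs)))  # -> 1.0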


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)
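
# Illustrative sketch (hypothetical metric function, not from the original module):
# take the best score over several acceptable answers, as in SQuAD-style QA scoring.
#
#     exact_match = lambda pred, gold: float(pred == gold)
#     metric_max_over_ground_truths(exact_match, "Paris", ["Paris", "paris, France"])  # -> 1.0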


def perplexity(items):
    return math.exp(-mean(items))
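
# Illustrative sketch (assumption: `items` is a sequence of log-likelihoods): perplexity
# is the exponential of the negated mean. With the hypothetical values below,
# mean(items) = -2.0, so the result is math.exp(2.0) ~= 7.39.
#
#     perplexity([-1.0, -2.0, -3.0])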


def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence to a reference sentence. It counts matching
    n-grams in the candidate translation to n-grams in the reference text, where
    1-gram or unigram would be each token and a bigram comparison would be each
    word pair. The comparison is made regardless of word order
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score
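
# Illustrative usage (not part of the original module): each item pairs a reference
# (or list of references) with a prediction, and the same format is used by chrf()
# and ter() below. Identical reference and prediction should give a BLEU of 100.
#
#     items = [("the cat sat on the mat", "the cat sat on the mat"),
#              ("a quick brown fox", "a quick brown fox")]
#     bleu(items)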


def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score


def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)


def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
    # Sacrebleu expects (List[str], List[List[str]])
    #   e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])

    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred
    # This is a different order of dimensions than I would expect

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
    # Must become List[List[str]] with the inner list corresponding to preds
    if not is_non_str_iterable(refs):
        # Wrap a bare string (or scalar) so it becomes a one-element list rather than a list of characters
        refs = [refs]
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))
    # Note the number of refs in each ref list must match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
    if not is_non_str_iterable(preds):
        # Wrap a bare string so it becomes a one-element list rather than a list of characters
        preds = [preds]
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds
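
# Illustrative sketch (hypothetical inputs, not part of the original module): the
# reshaping performed above, shown on two predictions with one reference each.
#
#     refs, preds = _sacreformat(["ref a", "ref b"], ["pred a", "pred b"])
#     # refs  -> [("ref a", "ref b")]   a single reference "stream" covering both preds
#     # preds -> ["pred a", "pred b"]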