import typing
import math
from collections.abc import Iterable

import numpy as np
import sacrebleu
from rouge_score import rouge_scorer
import sklearn.metrics
import random


def mean(arr):
    return sum(arr) / len(arr)


def pop_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


def sample_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))


def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))


def median(arr):
    # Middle element of the sorted values (upper median for even-length input).
    return sorted(arr)[len(arr) // 2]


def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return sklearn.metrics.matthews_corrcoef(golds, preds)


def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)

    return np.max(fscore)


def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1

        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc
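# Minimal usage sketch (illustrative, not part of the original module; assumes
# MultiRC-style docs where `doc["idx"]` carries paragraph/question ids):
#
#   items = [
#       (True,  {"idx": {"paragraph": 0, "question": 0}, "label": 1}),
#       (False, {"idx": {"paragraph": 0, "question": 0}, "label": 0}),
#       (True,  {"idx": {"paragraph": 0, "question": 1}, "label": 0}),
#   ]
#   acc_all(items)  # -> 0.5: question (0, 0) has all answers right, (0, 1) does not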


def acc_all_stderr(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)

    acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
    return acc


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def perplexity(items):
    return math.exp(-mean(items))
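# Minimal usage sketch (illustrative only; assumes `items` is a list of
# log-likelihoods, one per example, as aggregation metrics receive them):
#
#   lls = [-0.7, -1.2, -0.9]   # hypothetical log-likelihoods
#   perplexity(lls)            # exp(-mean(lls)) ~= 2.54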


def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)


def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))

def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)
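# Illustrative sketch (not in the original module): `weighted_perplexity` and
# `bits_per_byte` are assumed to receive (log-likelihood, weight) pairs, where
# the weight is a token/word/byte count, so the mean is length-weighted.
#
#   items = [(-45.0, 120), (-30.0, 80)]   # hypothetical (loglikelihood, num_bytes) pairs
#   weighted_perplexity(items)            # exp(75 / 200)       ~= 1.455
#   bits_per_byte(items)                  # (75 / 200) / ln(2)  ~= 0.541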


def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence against a reference sentence. It counts
    matching n-grams between the candidate translation and the reference text,
    where a 1-gram or unigram is a single token and a bigram is a word pair.
    The comparison is made regardless of word order.
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score
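# Minimal usage sketch (illustrative only): each item is assumed to be a
# (reference(s), prediction) pair; `_sacreformat` below reshapes them into the
# layout sacrebleu's corpus functions expect. The same item format applies to
# `chrf` and `ter`.
#
#   items = [
#       ("the cat sat on the mat", "the cat sat on a mat"),
#       ("hello world", "hello there world"),
#   ]
#   bleu(items)   # corpus BLEU on the 0-100 scale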


def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score


def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)


def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
    # Sacrebleu expects (List[str], List[List[str]])
    #   e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])

    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred
    # This is a different order of dimensions than I would expect

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
    # Must become List[List[str]] with the inner list corresponding to preds
    if not is_non_str_iterable(refs):
        # A bare reference (a single str) is wrapped rather than split into characters.
        refs = [refs]
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))
    # Note the number of refs in each ref list must match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
    if not is_non_str_iterable(preds):
        # A bare prediction (a single str) is wrapped rather than split into characters.
        preds = [preds]
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds
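# Worked example of the reshape above (illustrative): two predictions, one
# reference each,
#
#   _sacreformat(["ref a", "ref b"], ["pred a", "pred b"])
#   # -> ([("ref a", "ref b")], ["pred a", "pred b"])
#
# i.e. a single reference "stream" whose length matches the number of
# predictions, which is what sacrebleu's corpus_* functions expect.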


def rouge(
    refs: typing.List[str],
    pred: str,
    rouge_types: typing.List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
):
    """ ROUGE with multi-reference support

    Implementation based on GEM-metrics:
    https://github.com/GEM-benchmark/GEM-metrics/blob/431a8174bd6b3637e8d6118bfad2983e39e99733/gem_metrics/rouge.py

    :param refs:
        A `list` of reference `str`s.
    :param pred:
        A single prediction `str`.
    """

    # Add newlines between sentences to correctly compute `rougeLsum`.
    if "rougeLsum" in rouge_types:
        # TODO: Adapt this to handle languages that do not support sentence endings by `.`.
        # See GEM-metrics implementation with lang specific `nltk` tokenizers to
        # split sentences.
        pred = pred.replace(".", ".\n")
        refs = [ref.replace(".", ".\n") for ref in refs]

    scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True)
    # ROUGE multi-ref jackknifing
    if len(refs) > 1:
        cur_scores = [scorer.score(ref, pred) for ref in refs]

        # get best score for all leave-one-out sets
        best_scores = []
        for leave in range(len(refs)):
            cur_scores_leave_one = [
                cur_scores[s] for s in range(len(refs)) if s != leave
            ]
            best_scores.append(
                {
                    rouge_type: max(
                        [s[rouge_type] for s in cur_scores_leave_one],
                        key=lambda s: s.fmeasure,
                    )
                    for rouge_type in rouge_types
                }
            )
        # average the leave-one-out bests to produce the final score
        score = {
            rouge_type: rouge_scorer.scoring.Score(
                np.mean([b[rouge_type].precision for b in best_scores]),
                np.mean([b[rouge_type].recall for b in best_scores]),
                np.mean([b[rouge_type].fmeasure for b in best_scores]),
            )
            for rouge_type in rouge_types
        }
    else:
        score = scorer.score(refs[0], pred)
    # convert the named tuples to plain nested dicts
    score = {
        rouge_type: {
            "precision": score[rouge_type].precision,
            "recall": score[rouge_type].recall,
            "fmeasure": score[rouge_type].fmeasure,
        }
        for rouge_type in rouge_types
    }
    return score
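# Minimal usage sketch (illustrative only):
#
#   rouge(refs=["the cat sat on the mat"], pred="a cat sat on the mat")
#   # -> {"rouge1": {"precision": ..., "recall": ..., "fmeasure": ...}, ...}
#
# With more than one reference, the leave-one-out jackknifing above picks the
# best score per held-out set and averages them.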


# stderr stuff

class _bootstrap_internal:
    def __init__(self, f, n):
        self.f = f
        self.n = n

    def __call__(self, v):
        i, xs = v
        rnd = random.Random()
        rnd.seed(i)
        res = []
        for _ in range(self.n):
            res.append(self.f(rnd.choices(xs, k=len(xs))))
        return res


def bootstrap_stderr(f, xs, iters):
    import multiprocessing as mp
    pool = mp.Pool(mp.cpu_count())
    # This gives a biased estimate of the stderr (i.e. with the mean, it gives something
    # equivalent to the stderr calculated without Bessel's correction in the stddev.
    # Unfortunately, I haven't been able to figure out what the right correction is
    # to make the bootstrap unbiased - I considered multiplying by sqrt(n/(n-1)), but
    # that would be ad-hoc and I can't prove that it would actually be an unbiased estimator).
    # Thankfully, this shouldn't matter because our samples are usually pretty big.
    res = []
    chunk_size = min(1000, iters)
    from tqdm import tqdm
    print("bootstrapping for stddev:", f.__name__)
    for bootstrap in tqdm(pool.imap(
            _bootstrap_internal(f, chunk_size),
            [(i, xs) for i in range(iters // chunk_size)]), total=iters // chunk_size):
        # sample with replacement
        res.extend(bootstrap)

    pool.close()
    return sample_stddev(res)


def stderr_for_metric(metric, bootstrap_iters):
    bootstrappable = [
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

    stderr = {
        mean: mean_stderr,
        acc_all: acc_all_stderr,
    }

    return stderr.get(metric, None)
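# Minimal usage sketch (illustrative): look up (or build) the stderr function for
# a metric, then apply it to the raw per-example items.
#
#   accs = [1, 0, 1, 1, 0]
#   stderr_fn = stderr_for_metric(mean, bootstrap_iters=1000)
#   stderr_fn(accs)   # == mean_stderr(accs); `mean` is not bootstrapped
#
#   bleu_stderr = stderr_for_metric(bleu, bootstrap_iters=1000)   # bootstrap path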


def yesno(x):
    if x:
        return 'yes'
    else:
        return 'no'