#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of
candidate hypotheses.

See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
"""

import argparse
from itertools import chain
import sys
import random

import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu


def main():
    """Command-line entry point: print BLEU statistics for system outputs/references.

    With ``--sys``, the given files are parsed by :func:`load_sys` and their
    pairwise BLEU is printed. If ``--output`` is also given, a readable dump
    is written by :func:`merge`. With ``--ref``, the reference file is parsed
    by :func:`load_ref`. If system outputs were given, :func:`multi_ref`
    scores them against the references; otherwise :func:`intra_ref` scores
    the references against each other.
    """
    parser = argparse.ArgumentParser(sys.argv[0])
    parser.add_argument('--sys', nargs='*', default='', metavar='FILE',
                        help='path to system output')
    parser.add_argument('--ref', default='', metavar='FILE',
                        help='path to references')
    parser.add_argument('--output', default='', metavar='FILE',
                        help='print outputs into a pretty format')
    args = parser.parse_args()

    sys_paths = args.sys
    ref_path = args.ref

    if sys_paths:
        src, tgt, hypos, log_probs = load_sys(sys_paths)
        print('pairwise BLEU: %.2f' % pairwise(hypos))
        if args.output:
            merge(src, tgt, hypos, log_probs, args.output)

    # Nothing more to do without a reference file.
    if not ref_path:
        return

    refs = load_ref(ref_path)[2]
    if sys_paths:
        multi_ref(refs, hypos)
    else:
        intra_ref(refs)


def dictolist(d):
    """Return the values of mapping ``d`` as a list, ordered by ascending key."""
    return [d[key] for key in sorted(d)]

def load_sys(paths):
    """Parse one or more fairseq-style generation output files.

    Lines tagged ``S-<id>`` (source) and ``T-<id>`` (target) carry their text
    in the second tab-separated column. Lines tagged ``H-<id>`` carry a score
    in the second column and the hypothesis text in the third. All other
    lines are ignored. Hypotheses for the same id are accumulated across all
    files, in file order.

    Args:
        paths: iterable of file paths to read (decoded as UTF-8).

    Returns:
        ``(src, tgt, hypos, log_probs)``, each a list ordered by ascending id.
        ``hypos[k]`` is the list of hypothesis strings for the k-th id, and
        ``log_probs[k]`` holds the matching float scores.
    """
    src, tgt, hypos, log_probs = {}, {}, {}, {}
    for path in paths:
        with open(path, encoding='utf-8') as f:
            for line in f:
                # Split before stripping whitespace. A blanket rstrip() would
                # also drop the tab in front of an empty last field (e.g. an
                # empty hypothesis or target), and the field lookup below
                # would then raise IndexError.
                fields = line.rstrip('\r\n').split('\t')
                tag = fields[0][:2]
                if tag not in ('S-', 'T-', 'H-'):
                    continue
                i = int(fields[0][2:])
                if tag == 'S-':
                    src[i] = fields[1].rstrip()
                elif tag == 'T-':
                    tgt[i] = fields[1].rstrip()
                else:
                    hypos.setdefault(i, []).append(fields[2].rstrip())
                    log_probs.setdefault(i, []).append(float(fields[1]))
    return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)

def load_ref(path):
    """Parse a reference file made of ``S-``, ``T-`` and ``R`` lines.

    ``S-`` and ``T-`` lines carry the source and target text in the second
    tab-separated column. Each maximal run of consecutive lines starting with
    ``R`` is collected into one list of reference strings (second column).
    Any other line, such as a blank line, is skipped.

    Args:
        path: path of the reference file (decoded as UTF-8).

    Returns:
        ``(src, tgt, refs)``: lists of source strings and target strings, and
        a list with one reference list per run of ``R`` lines, in file order.
    """
    with open(path, encoding='utf-8') as f:
        lines = f.readlines()
    src, tgt, refs = [], [], []
    i = 0
    while i < len(lines):
        line = lines[i]
        if line.startswith('S-'):
            src.append(line.split('\t')[1].rstrip())
            i += 1
        elif line.startswith('T-'):
            tgt.append(line.split('\t')[1].rstrip())
            i += 1
        elif line.startswith('R'):
            group = []
            while i < len(lines) and lines[i].startswith('R'):
                group.append(lines[i].split('\t')[1].rstrip())
                i += 1
            refs.append(group)
        else:
            # Skip unrecognized lines. Previously they fell into the R-run
            # branch without advancing ``i``, which looped forever appending
            # empty lists.
            i += 1
    return src, tgt, refs

def merge(src, tgt, hypos, log_probs, path):
    """Write sources, targets and scored hypotheses to ``path`` for inspection.

    For each sentence the file receives:
      - the source line;
      - the target line;
      - a blank line;
      - one line per hypothesis made of a leading tab, the score (``%f``),
        a tab, and the whitespace-stripped hypothesis;
      - a dashed separator line.

    Sentences are paired with ``zip``, so output stops at the shortest of the
    four input lists. The file is written as UTF-8, matching the readers in
    this script, so the output does not depend on the process locale.
    """
    with open(path, 'w', encoding='utf-8') as f:
        for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
            f.write(s + '\n')
            f.write(t + '\n')
            f.write('\n')
            for h, lp in zip(hs, lps):
                f.write('\t%f\t%s\n' % (lp, h.strip()))
            f.write('------------------------------------------------------\n')


def corpus_bleu(sys_stream, ref_streams):
    """Score ``sys_stream`` against ``ref_streams`` with sacreBLEU; return the score.

    ``tokenize='none'`` disables sacreBLEU's own tokenization, so the inputs
    are presumably already tokenized upstream (confirm against the data
    pipeline).
    """
    return _corpus_bleu(sys_stream, ref_streams, tokenize='none').score


def sentence_bleu(hypothesis, reference):
    """Return a smoothed sentence-level BLEU score of ``hypothesis`` vs ``reference``.

    :func:`multi_ref` uses this to find the best-matching reference for each
    hypothesis. Unlike :func:`corpus_bleu`, no ``tokenize`` argument is
    passed here, so sacreBLEU's default tokenizer applies.
    """
    # NOTE(review): relies on sacrebleu accepting a bare string as a
    # one-segment system stream / reference -- confirm for the pinned version.
    bleu = _corpus_bleu(hypothesis, reference)
    # Add one to the match counts and totals at indices 1..3 (presumably the
    # 2- to 4-gram orders), presumably so that missing higher-order matches
    # do not force the score to zero.
    for i in range(1, 4):
        bleu.counts[i] += 1
        bleu.totals[i] += 1
    # Recompute the score from the adjusted statistics with exponential
    # smoothing. NOTE(review): ``smooth``/``smooth_floor`` are keyword names
    # from older sacrebleu releases (newer ones use ``smooth_method``) --
    # confirm against the installed version.
    bleu = compute_bleu(
        bleu.counts, bleu.totals,
        bleu.sys_len, bleu.ref_len,
        smooth='exp', smooth_floor=0.0,
    )
    return bleu.score


def pairwise(sents):
    """Return corpus BLEU pooled over all ordered pairs within each group.

    ``sents`` is a sequence of groups (e.g. all hypotheses for one source
    sentence). For every group and every ordered pair of distinct positions
    ``(a, b)``, item ``a`` is used as the reference and item ``b`` as the
    hypothesis. All pairs from all groups form a single corpus.
    """
    pairs = [
        (ref, hyp)
        for group in sents
        for a, ref in enumerate(group)
        for b, hyp in enumerate(group)
        if a != b
    ]
    ref_side = [ref for ref, _ in pairs]
    hyp_side = [hyp for _, hyp in pairs]
    return corpus_bleu(hyp_side, [ref_side])


def multi_ref(refs, hypos):
    """Print reference coverage and leave-one-out multi-reference BLEU.

    Args:
        refs: one list of reference strings per sentence.
        hypos: one list of hypothesis strings per sentence, aligned with
            ``refs``.

    Prints two lines:
      - ``#refs covered``: the average number of distinct references that
        are the best sentence-BLEU match of at least one hypothesis;
      - ``average multi-reference BLEU (leave-one-out)``.

    Raises:
        ValueError: if ``refs`` and ``hypos`` cover a different number of
            sentences.
    """
    if len(refs) != len(hypos):
        # Explicit check rather than ``assert``, so it still fires under -O.
        # Otherwise zip() would silently drop the unmatched tail.
        raise ValueError(
            'got references for %d sentences but hypotheses for %d'
            % (len(refs), len(hypos)))

    print('#refs covered: %.2f' % (_count_covered_refs(refs, hypos) / len(refs)))
    print('average multi-reference BLEU (leave-one-out): %.2f'
          % _leave_one_out_multi_ref_bleu(refs, hypos))


def _count_covered_refs(refs, hypos):
    """Return the summed number of distinct references matched per sentence.

    Each hypothesis is matched to the reference with the highest
    :func:`sentence_bleu` score. Ties are broken uniformly at random.
    """
    covered = 0
    for rs, hs in zip(refs, hypos):
        chosen = set()
        for h in hs:
            scores = [sentence_bleu(h, r) for r in rs]
            top = scores[np.argmax(scores)]
            best = [k for k, score in enumerate(scores) if score == top]
            chosen.add(random.choice(best))
        covered += len(chosen)
    return covered


def _leave_one_out_multi_ref_bleu(refs, hypos):
    """Average corpus BLEU of all hypotheses, holding out one reference stream at a time.

    Leave-one-out keeps the number comparable to :func:`intra_ref`. The
    inputs are transposed into parallel streams (``zip`` truncates to the
    shortest per-sentence list). All hypothesis streams are interleaved per
    sentence, and each reference is repeated once per hypothesis stream to
    stay aligned.
    """
    ref_streams = list(zip(*refs))
    hypo_streams = list(zip(*hypos))
    k = len(hypo_streams)
    m = len(ref_streams)
    flat_hypos = [hypo_streams[j][i] for i in range(len(hypo_streams[0])) for j in range(k)]
    duplicated_refs = [[ref for ref in stream for _ in range(k)] for stream in ref_streams]
    loo_bleus = [
        corpus_bleu(flat_hypos, duplicated_refs[:held_out] + duplicated_refs[held_out + 1:])
        for held_out in range(m)
    ]
    return np.mean(loo_bleus)


def intra_ref(refs):
    """Print pairwise BLEU among references and their leave-one-out multi-ref BLEU.

    ``refs`` holds one list of references per sentence. First the pairwise
    BLEU of :func:`pairwise` is printed. Then the references are transposed
    into parallel streams. Each stream in turn acts as the hypothesis and is
    scored against all remaining streams, with everything pooled into one
    corpus BLEU computation.
    """
    print('ref pairwise BLEU: %.2f' % pairwise(refs))
    streams = list(zip(*refs))
    n_streams = len(streams)
    hypothesis_side = []
    reference_side = [[] for _ in range(n_streams - 1)]
    for held_out, stream in enumerate(streams):
        hypothesis_side.extend(stream)
        others = streams[:held_out] + streams[held_out + 1:]
        for slot, other in zip(reference_side, others):
            slot.extend(other)
    bleu = corpus_bleu(hypothesis_side, reference_side)
    print('multi-reference BLEU (leave-one-out): %.2f' % bleu)


if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script.
    main()