#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Myle Ott's avatar
Myle Ott committed
7
Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of
8
9
10
11
12
13
14
candidate hypotheses.

See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
"""

import argparse
from itertools import chain
import sys
import random

import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu


def main():
    """Command-line entry point.

    ``--sys`` takes zero or more generation output files: their pairwise BLEU
    is printed and, if ``--output`` is given, a merged readable dump is
    written there. ``--ref`` takes a reference file: together with ``--sys``
    the hypotheses are scored against it, otherwise the references are
    scored against each other.
    """
    parser = argparse.ArgumentParser(prog=sys.argv[0])
    parser.add_argument('--sys', nargs='*', default='', metavar='FILE',
                        help='path to system output')
    parser.add_argument('--ref', default='', metavar='FILE',
                        help='path to references')
    parser.add_argument('--output', default='', metavar='FILE',
                        help='print outputs into a pretty format')
    opts = parser.parse_args()

    have_sys = bool(opts.sys)
    if have_sys:
        src, tgt, hypos, log_probs = load_sys(opts.sys)
        print('pairwise BLEU: %.2f' % pairwise(hypos))
        if opts.output:
            merge(src, tgt, hypos, log_probs, opts.output)

    if not opts.ref:
        return
    refs = load_ref(opts.ref)[2]
    if have_sys:
        multi_ref(refs, hypos)
    else:
        intra_ref(refs)


def dictolist(d):
    """Return the values of mapping *d* as a list, ordered by ascending key."""
    return [d[key] for key in sorted(d)]


def load_sys(paths):
    """Parse generation output files into per-sentence lists.

    Only lines starting with ``S-`` (source), ``T-`` (target) or ``H-``
    (hypothesis) are used; all other lines are ignored. Fields are
    TAB-separated: ``S-<id>``/``T-<id>`` lines carry the text in field 1,
    ``H-<id>`` lines carry a float score in field 1 and the text in field 2.
    Text fields have trailing whitespace removed; a missing/empty text field
    yields ``''``.

    Hypotheses and scores from all files in *paths* are accumulated per id in
    file order; for ``S-``/``T-`` lines the last occurrence of an id wins.

    Args:
        paths: iterable of file paths, read as UTF-8.

    Returns:
        ``(src, tgt, hypos, log_probs)``, each a list ordered by ascending
        sentence id; ``hypos[i]`` is a list of strings and ``log_probs[i]``
        the matching list of floats.
    """
    src, tgt, hypos, log_probs = {}, {}, {}, {}
    for path in paths:
        with open(path, encoding='utf-8') as f:
            for line in f:
                if not line.startswith(('S-', 'T-', 'H-')):
                    continue
                # Strip only the line terminator before splitting: a plain
                # rstrip() would also remove the trailing TAB of an empty
                # sentence/hypothesis, making the text field disappear.
                fields = line.rstrip('\r\n').split('\t')
                i = int(fields[0][2:])
                tag = line[:2]
                if tag == 'H-':
                    text = fields[2] if len(fields) > 2 else ''
                    hypos.setdefault(i, []).append(text.rstrip())
                    log_probs.setdefault(i, []).append(float(fields[1]))
                else:
                    text = (fields[1] if len(fields) > 1 else '').rstrip()
                    if tag == 'S-':
                        src[i] = text
                    else:
                        tgt[i] = text
    return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)


def load_ref(path):
    """Read a multi-reference file.

    Lines starting with ``S-`` / ``T-`` contribute their second TAB-separated
    field to *src* / *tgt*. Each run of consecutive lines starting with ``R``
    forms one group whose second fields are appended to *refs* as a list.
    Any other line (e.g. a blank line) is skipped. Every kept field has
    trailing whitespace removed.

    Args:
        path: file path, read as UTF-8.

    Returns:
        ``(src, tgt, refs)`` where *refs* is a list of lists of strings.
    """
    with open(path, encoding='utf-8') as f:
        lines = f.readlines()
    src, tgt, refs = [], [], []
    n = len(lines)
    i = 0
    while i < n:
        line = lines[i]
        if line.startswith('S-'):
            src.append(line.split('\t')[1].rstrip())
            i += 1
        elif line.startswith('T-'):
            tgt.append(line.split('\t')[1].rstrip())
            i += 1
        elif line.startswith('R'):
            group = []
            while i < n and lines[i].startswith('R'):
                group.append(lines[i].split('\t')[1].rstrip())
                i += 1
            refs.append(group)
        else:
            # Unrecognised line: skip it. Without advancing here the loop
            # would never terminate (and would keep appending empty groups).
            i += 1
    return src, tgt, refs


def merge(src, tgt, hypos, log_probs, path):
    """Write a human-readable dump of sources, targets and hypotheses.

    For each sentence the output contains the source line, the target line,
    an empty line, one ``<TAB><score><TAB><hypothesis>`` line per hypothesis
    (score formatted with ``%f``, hypothesis stripped of surrounding
    whitespace) and a dashed separator line. Sentences are paired
    positionally with ``zip``, so output stops at the shortest input list.

    Args:
        src: source strings.
        tgt: target strings.
        hypos: per-sentence lists of hypothesis strings.
        log_probs: per-sentence lists of scores, aligned with *hypos*.
        path: output file path; overwritten, written as UTF-8.
    """
    separator = '------------------------------------------------------\n'
    # Explicit UTF-8 so non-ASCII text does not depend on the process locale.
    with open(path, 'w', encoding='utf-8') as f:
        for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
            block = [s + '\n', t + '\n', '\n']
            block.extend('\t%f\t%s\n' % (lp, h.strip()) for h, lp in zip(hs, lps))
            block.append(separator)
            f.writelines(block)


def corpus_bleu(sys_stream, ref_streams):
    """Return sacreBLEU's corpus-level score for *sys_stream*.

    sacreBLEU's own tokenization is disabled (``tokenize='none'``), so the
    system stream and each reference stream are scored exactly as given.
    """
    return _corpus_bleu(sys_stream, ref_streams, tokenize='none').score


def sentence_bleu(hypothesis, reference):
    """Smoothed BLEU of one *hypothesis* string against one *reference*.

    The n-gram statistics come from sacreBLEU's ``corpus_bleu`` with its
    default settings. One is added to both the match counts and the totals
    at indices 1-3 (presumably n-gram orders 2-4 -- confirm against the
    installed sacreBLEU). The score is then recomputed with ``compute_bleu``
    using ``smooth='exp'`` and ``smooth_floor=0.0``.
    """
    stats = _corpus_bleu(hypothesis, reference)
    counts, totals = stats.counts, stats.totals
    for order in (1, 2, 3):
        counts[order] += 1
        totals[order] += 1
    smoothed = compute_bleu(counts, totals, stats.sys_len, stats.ref_len,
                            smooth='exp', smooth_floor=0.0)
    return smoothed.score


def pairwise(sents):
    """Corpus BLEU of every item against every *other* item of its group.

    For each group in *sents*, all ordered pairs of distinct positions
    ``(i, j)`` are formed. ``group[j]`` goes on the system side and
    ``group[i]`` is its single reference; all pairs from all groups are
    scored as one corpus.
    """
    ref_side, sys_side = [], []
    for group in sents:
        for i, ref in enumerate(group):
            for j, hyp in enumerate(group):
                if i == j:
                    continue
                ref_side.append(ref)
                sys_side.append(hyp)
    return corpus_bleu(sys_side, [ref_side])


def multi_ref(refs, hypos):
    """Print reference coverage and leave-one-out multi-reference BLEU.

    Args:
        refs: one list of reference strings per sentence.
        hypos: one list of hypothesis strings per sentence, aligned with
            *refs* (same number of sentences).

    Prints two lines:

    * ``#refs covered``: each hypothesis is matched to the reference with
      the highest ``sentence_bleu`` score, with ties broken with
      ``random.choice``. The number of distinct references matched per
      sentence is then averaged over sentences.
    * ``average multi-reference BLEU (leave-one-out)``: for each reference
      index r, corpus BLEU of all hypotheses against the other references.
      The mean over r is printed.
    """
    assert len(refs) == len(hypos), 'refs and hypos must cover the same sentences'

    # count number of refs covered
    ref_cnt = 0
    for rs, hs in zip(refs, hypos):
        covered = set()
        for h in hs:
            scores = [sentence_bleu(h, r) for r in rs]
            top = max(scores)
            best = [idx for idx, score in enumerate(scores) if score == top]
            covered.add(random.choice(best))
        ref_cnt += len(covered)
    print('#refs covered: %.2f' % (ref_cnt / len(refs)))

    # Transpose: refs_by_idx[r] holds the r-th reference of every sentence
    # and hypos_by_idx[j] the j-th hypothesis of every sentence. zip()
    # truncates to the shortest per-sentence list.
    refs_by_idx = list(zip(*refs))
    hypos_by_idx = list(zip(*hypos))

    # compute multi-ref corpus BLEU (leave-one-out to be comparable to intra_ref)
    k = len(hypos_by_idx)
    m = len(refs_by_idx)
    # Sentence-major order: the k hypotheses of sentence 0, then sentence 1, ...
    flat_hypos = [h for per_sent in zip(*hypos_by_idx) for h in per_sent]
    # Repeat every reference k times so it lines up with flat_hypos.
    duplicated_refs = [
        [ref for ref in refs_i for _ in range(k)]
        for refs_i in refs_by_idx
    ]
    loo_bleus = []
    for held_out_ref in range(m):
        remaining_refs = duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1:]
        assert len(remaining_refs) == m - 1
        loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
    print('average multi-reference BLEU (leave-one-out): %.2f' % np.mean(loo_bleus))


def intra_ref(refs):
    """Print BLEU statistics of the references measured against each other.

    First prints the pairwise BLEU among each sentence's references. Then
    each reference index takes a turn as the "system": its references for
    all sentences are concatenated into one system stream and scored, as a
    single corpus, against the remaining reference indices.

    Args:
        refs: one list of reference strings per sentence.
    """
    print('ref pairwise BLEU: %.2f' % pairwise(refs))
    by_ref = list(zip(*refs))
    m = len(by_ref)
    sys_stream = []
    ref_streams = [[] for _ in range(m - 1)]
    for idx, held_out in enumerate(by_ref):
        others = by_ref[:idx] + by_ref[idx + 1:]
        sys_stream.extend(held_out)
        for stream, other in zip(ref_streams, others):
            stream.extend(other)
    bleu = corpus_bleu(sys_stream, ref_streams)
    print('multi-reference BLEU (leave-one-out): %.2f' % bleu)


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script.
    main()