#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BLEU scoring of generated translations against reference translations.
"""
Sergey Edunov's avatar
Sergey Edunov committed
9
10
11
12
13

import argparse
import os
import sys

14
from fairseq import bleu
alexeib's avatar
alexeib committed
15
from fairseq.data import dictionary
Sergey Edunov's avatar
Sergey Edunov committed
16
17


Myle Ott's avatar
Myle Ott committed
18
def get_parser():
    """Build the command-line argument parser for the BLEU scoring script.

    Returns:
        argparse.ArgumentParser: a parser accepting ``-s/--sys`` (default
        ``'-'``), ``-r/--ref`` (required), ``-o/--order`` (int, default 4),
        and the boolean flags ``--ignore-case``, ``--sacrebleu`` and
        ``--sentence-bleu``.
    """
    parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
    # fmt: off
    parser.add_argument('-s', '--sys', default='-', help='system output')
    parser.add_argument('-r', '--ref', required=True, help='references')
    parser.add_argument('-o', '--order', default=4, metavar='N',
                        type=int, help='consider ngrams up to this order')
    parser.add_argument('--ignore-case', action='store_true',
                        help='case-insensitive scoring')
    parser.add_argument('--sacrebleu', action='store_true',
                        help='score with sacrebleu')
    parser.add_argument('--sentence-bleu', action='store_true',
                        help='report sentence-level BLEUs (i.e., with +1 smoothing)')
    # fmt: on
    return parser
Sergey Edunov's avatar
Sergey Edunov committed
33

Myle Ott's avatar
Myle Ott committed
34
35
36

def main():
    """Score system output against a reference file and print the BLEU result.

    Parses the command line with ``get_parser()``, checks that the input
    paths exist, then prints a score using one of three back-ends chosen by
    flags:

    * ``--sacrebleu``: corpus-level score computed by ``sacrebleu``;
    * ``--sentence-bleu``: one result line per sentence pair, prefixed with
      the pair's index;
    * otherwise: a single corpus-level result from ``fairseq.bleu.Scorer``.

    System output is read from stdin when ``--sys`` is ``-``.

    Raises:
        AssertionError: if the system output or reference path does not exist.
    """
    parser = get_parser()
    args = parser.parse_args()
    print(args)

    # NOTE: these checks are plain asserts, so they are skipped under
    # ``python -O``; a missing file then surfaces when it is opened instead.
    assert args.sys == '-' or os.path.exists(args.sys), \
        "System output file {} does not exist".format(args.sys)
    assert os.path.exists(args.ref), \
        "Reference file {} does not exist".format(args.ref)

    # Named ``vocab`` (not ``dict``) so the builtin is not shadowed.
    vocab = dictionary.Dictionary()

    def readlines(fd):
        """Yield lines from *fd*, lowercased when --ignore-case is set."""
        # Iterate the handle directly instead of materializing every line.
        for line in fd:
            yield line.lower() if args.ignore_case else line

    if args.sacrebleu:
        import sacrebleu

        def score(fdsys):
            with open(args.ref) as fdref:
                # The raw streams bypass readlines(), so --ignore-case has to
                # be forwarded explicitly or it would be silently ignored.
                print(sacrebleu.corpus_bleu(
                    fdsys, [fdref], lowercase=args.ignore_case,
                ))

    elif args.sentence_bleu:
        def score(fdsys):
            with open(args.ref) as fdref:
                scorer = bleu.Scorer(vocab.pad(), vocab.eos(), vocab.unk())
                # zip() stops at the shorter input, so surplus lines in either
                # file are not scored.
                pairs = zip(readlines(fdsys), readlines(fdref))
                for i, (sys_line, ref_line) in enumerate(pairs):
                    # Reset per pair so each printed result covers one sentence.
                    scorer.reset(one_init=True)
                    sys_tok = vocab.encode_line(sys_line)
                    ref_tok = vocab.encode_line(ref_line)
                    scorer.add(ref_tok, sys_tok)
                    print(i, scorer.result_string(args.order))

    else:
        def score(fdsys):
            with open(args.ref) as fdref:
                scorer = bleu.Scorer(vocab.pad(), vocab.eos(), vocab.unk())
                # zip() stops at the shorter input, so surplus lines in either
                # file are not scored.
                for sys_line, ref_line in zip(readlines(fdsys), readlines(fdref)):
                    sys_tok = vocab.encode_line(sys_line)
                    ref_tok = vocab.encode_line(ref_line)
                    scorer.add(ref_tok, sys_tok)
                print(scorer.result_string(args.order))

    if args.sys == '-':
        score(sys.stdin)
    else:
        with open(args.sys, 'r') as f:
            score(f)


# Run the scorer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()