compute_bleu.py 4.63 KB
Newer Older
Katherine Wu's avatar
Katherine Wu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to compute official BLEU score.

Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import re
import sys
import unicodedata

# pylint: disable=g-bad-import-order
import six
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.transformer.utils import metrics


class UnicodeRegex(object):
  """Ad-hoc hack to recognize all punctuation and symbols."""

  def __init__(self):
    punctuation = self.property_chars("P")
    self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
    self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
    self.symbol_re = re.compile("([" + self.property_chars("S") + "])")

  def property_chars(self, prefix):
    return "".join(six.unichr(x) for x in range(sys.maxunicode)
                   if unicodedata.category(six.unichr(x)).startswith(prefix))


uregex = UnicodeRegex()


def bleu_tokenize(string):
  r"""Tokenize a string following the official BLEU implementation.

  See https://github.com/moses-smt/mosesdecoder/'
           'blob/master/scripts/generic/mteval-v14.pl#L954-L983
  In our case, the input string is expected to be just one line
  and no HTML entities de-escaping is needed.
  So we just tokenize on punctuation and symbols,
  except when a punctuation is preceded and followed by a digit
  (e.g. a comma/dot as a thousand/decimal separator).

  Note that a numer (e.g. a year) followed by a dot at the end of sentence
  is NOT tokenized,
  i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
  does not match this case (unless we add a space after each sentence).
  However, this error is already in the original mteval-v14.pl
  and we want to be consistent with it.

  Args:
    string: the input string

  Returns:
    a list of tokens
  """
  string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
  string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
  string = uregex.symbol_re.sub(r" \1 ", string)
  return string.split()


def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
  """Compute BLEU for two files (reference and hypothesis translation)."""
  ref_lines = tf.gfile.Open(ref_filename).read().strip().splitlines()
  hyp_lines = tf.gfile.Open(hyp_filename).read().strip().splitlines()

  if len(ref_lines) != len(hyp_lines):
    raise ValueError("Reference and translation files have different number of "
                     "lines.")
  if not case_sensitive:
    ref_lines = [x.lower() for x in ref_lines]
    hyp_lines = [x.lower() for x in hyp_lines]
  ref_tokens = [bleu_tokenize(x) for x in ref_lines]
  hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
  return metrics.compute_bleu(ref_tokens, hyp_tokens) * 100


def main(unused_argv):
  if FLAGS.bleu_variant is None or "uncased" in FLAGS.bleu_variant:
    score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
    print("Case-insensitive results:", score)

  if FLAGS.bleu_variant is None or "cased" in FLAGS.bleu_variant:
    score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
    print("Case-sensitive results:", score)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--translation", "-t", type=str, default=None, required=True,
      help="[default: %(default)s] File containing translated text.",
      metavar="<T>")
  parser.add_argument(
      "--reference", "-r", type=str, default=None, required=True,
      help="[default: %(default)s] File containing reference translation",
      metavar="<R>")
  parser.add_argument(
      "--bleu_variant", "-bv", type=str, choices=["uncased", "cased"],
      nargs="*", default=None,
      help="Specify one or more BLEU variants to calculate (both are "
           "calculated by default. Variants: \"cased\" or \"uncased\".",
      metavar="<BV>")

  FLAGS, unparsed = parser.parse_known_args()
  main(sys.argv)