#!/usr/bin/python
import os
import re
import sys
import math
import subprocess
import xml.sax.saxutils

from typing import List, Pattern, Tuple, Union, Dict, Any, Optional
"""
This script was adapted from the original version by hieuhoang1972 which is part of MOSES.
"""

# $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $

"""Provides:
cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
score_cooked(alltest, n=4): Score a list of cooked test sentences.

score_set(s, testid, refids, n=4): Interface with dataset.py; calculate BLEU score of testid against refids.

The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible.
"""
# Added to bypass NIST-style pre-processing of hyp and ref files -- wade
nonorm = 0

preserve_case = False
eff_ref_len = "shortest"

34
normalize1: List[Tuple[Union[Pattern[str], str], str]] = [
lintangsutawika's avatar
lintangsutawika committed
35
36
37
38
    ("<skipped>", ""),  # strip "skipped" tags
    (r"-\n", ""),  # strip end-of-line hyphenation and join lines
    (r"\n", " "),  # join lines
    #    (r'(\d)\s+(?=\d)', r'\1'), # join digits
lintangsutawika's avatar
lintangsutawika committed
39
40
41
]
normalize1 = [(re.compile(pattern), replace) for (pattern, replace) in normalize1]

42
normalize2: List[Tuple[Union[Pattern[str], str], str]] = [
lintangsutawika's avatar
lintangsutawika committed
43
44
45
46
47
48
49
50
51
52
53
54
55
    (
        r"([\{-\~\[-\` -\&\(-\+\:-\@\/])",
        r" \1 ",
    ),  # tokenize punctuation. apostrophe is missing
    (
        r"([^0-9])([\.,])",
        r"\1 \2 ",
    ),  # tokenize period and comma unless preceded by a digit
    (
        r"([\.,])([^0-9])",
        r" \1 \2",
    ),  # tokenize period and comma unless followed by a digit
    (r"([0-9])(-)", r"\1 \2 "),  # tokenize dash when preceded by a digit
lintangsutawika's avatar
lintangsutawika committed
56
57
58
]
normalize2 = [(re.compile(pattern), replace) for (pattern, replace) in normalize2]

def normalize(s):
    """Normalize and tokenize text. This is lifted from NIST mteval-v11a.pl."""
    # Added to bypass NIST-style pre-processing of hyp and ref files -- wade
    if nonorm:
        return s.split()
    if type(s) is not str:
        s = " ".join(s)
    # Language-independent part: apply the pre-compiled substitutions in order.
    for pattern, replacement in normalize1:
        s = pattern.sub(replacement, s)
    s = xml.sax.saxutils.unescape(s, {"&quot;": '"'})
    # Language-dependent part (assuming Western languages): pad with spaces so
    # the tokenizing patterns can anchor at the string edges.
    s = " " + s + " "
    if not preserve_case:
        s = s.lower()  # this might not be identical to the original
    for pattern, replacement in normalize2:
        s = pattern.sub(replacement, s)
    return s.split()

def count_ngrams(words, n=4):
    """Count every n-gram of *words* of length 1..n.

    Returns a dict mapping each n-gram (as a tuple of tokens) to its
    occurrence count within *words*.
    """
    counts: Dict[Any, int] = {}
    for length in range(1, n + 1):
        for start in range(len(words) - length + 1):
            gram = tuple(words[start : start + length])
            counts[gram] = 1 + counts.get(gram, 0)
    return counts

def cook_refs(refs, n=4):
    """Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them: (list of reference lengths, per-ngram
    maximum count over all references)."""
    tokenized = [normalize(ref) for ref in refs]
    maxcounts: Dict[Tuple[str], int] = {}
    for ref in tokenized:
        for ngram, count in count_ngrams(ref, n).items():
            # Keep the highest count seen for this ngram across references.
            if count > maxcounts.get(ngram, 0):
                maxcounts[ngram] = count
    reflens = [len(ref) for ref in tokenized]
    return (reflens, maxcounts)

def cook_test(test, item, n=4):
    """Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.

    *item* is the (reflens, refmaxcounts) pair produced by cook_refs().
    The result dict carries: testlen, reflen (effective reference length
    per the module-level eff_ref_len policy), guess (candidate n-gram
    totals per order) and correct (clipped n-gram matches per order).
    """
    reflens, refmaxcounts = item
    words = normalize(test)
    testlen = len(words)
    result: Dict[str, Any] = {"testlen": testlen}

    # Calculate effective reference sentence length.
    if eff_ref_len == "shortest":
        result["reflen"] = min(reflens)
    elif eff_ref_len == "average":
        result["reflen"] = sum(reflens) / float(len(reflens))
    elif eff_ref_len == "closest":
        best_diff: Optional[int] = None
        for reflen in reflens:
            diff = abs(reflen - testlen)
            if best_diff is None or diff < best_diff:
                best_diff = diff
                result["reflen"] = reflen

    # Total possible n-grams of each order in the candidate.
    result["guess"] = [max(testlen - k + 1, 0) for k in range(1, n + 1)]

    # Clipped matches: each candidate n-gram counts at most as often as it
    # appears in the most generous reference.
    correct = [0] * n
    for ngram, count in count_ngrams(words, n).items():
        correct[len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
    result["correct"] = correct

    return result

def score_cooked(allcomps, n=4, ground=0, smooth=1):
    """Score a list of cooked test sentences.

    Aggregates the per-sentence statistics in *allcomps* and returns a list
    of n + 1 values: [smoothed corpus BLEU, p_1, ..., p_n], all expressed
    as probabilities (brevity penalty is folded into the first entry only).
    *ground* is kept for interface compatibility; smoothing (+1 on orders
    above unigram) is applied when smooth == 1.
    """
    # Sum the component statistics over all sentences.
    total_testlen = 0
    total_reflen = 0
    total_guess = [0] * n
    total_correct = [0] * n
    for comps in allcomps:
        total_testlen += comps["testlen"]
        total_reflen += comps["reflen"]
        for k in range(n):
            total_guess[k] += comps["guess"][k]
            total_correct[k] += comps["correct"][k]

    tiny = sys.float_info.min  # keeps log() defined when counts are zero
    logbleu = 0.0
    all_bleus: List[float] = []
    for k in range(n):
        correct = total_correct[k]
        guess = total_guess[k]
        addsmooth = 1 if (smooth == 1 and k > 0) else 0
        logbleu += math.log(correct + addsmooth + tiny) - math.log(
            guess + addsmooth + tiny
        )
        if guess == 0:
            # Sentinel log-precision: exp() of it underflows to 0.0.
            all_bleus.append(-10000000.0)
        else:
            all_bleus.append(math.log(correct + tiny) - math.log(guess))

    logbleu /= float(n)
    all_bleus.insert(0, logbleu)

    # Brevity penalty (in log space), applied to the combined score only.
    brev_penalty = min(0, 1 - float(total_reflen + 1) / (total_testlen + 1))
    all_bleus[0] += brev_penalty
    return [math.exp(b) for b in all_bleus]


def bleu(refs, candidate, ground=0, smooth=1):
    """Compute smoothed BLEU of *candidate* against the reference strings *refs*.

    Returns the list produced by score_cooked(): [smoothed BLEU, p_1..p_4].
    """
    cooked_refs = cook_refs(refs)
    cooked_candidate = cook_test(candidate, cooked_refs)
    return score_cooked([cooked_candidate], ground=ground, smooth=smooth)

def splitPuncts(line):
    """Re-tokenize *line* so word runs and punctuation marks are space-separated."""
    tokens = re.findall(r"[\w]+|[^\s\w]", line)
    return " ".join(tokens)

def computeMaps(predictions, goldfile):
    """Build id -> [tokenized sentence] maps for predictions and gold references.

    predictions: iterable of "id\tsentence" strings (a row without a tab is
    treated as an id with an empty prediction).
    goldfile: path to a tab-separated file of "id\tsentence" reference rows;
    a gold row is kept only when its id also appears among the predictions.
    Writes the number of matched ids to stderr and returns
    (goldMap, predictionMap).
    """
    predictionMap: Dict[str, list] = {}
    goldMap: Dict[str, list] = {}

    for row in predictions:
        cols = row.strip().split("\t")
        if len(cols) == 1:
            (rid, pred) = (cols[0], "")
        else:
            (rid, pred) = (cols[0], cols[1])
        predictionMap[rid] = [splitPuncts(pred.strip().lower())]

    # Fix: the gold file was opened but never closed; use a context manager
    # so the handle is released deterministically.
    with open(goldfile, "r") as gf:
        for row in gf:
            # NOTE(review): assumes exactly one tab per gold row — extra tabs
            # raise ValueError, as in the original.
            (rid, pred) = row.split("\t")
            if rid in predictionMap:  # Only insert if the id exists for the method
                goldMap.setdefault(rid, []).append(splitPuncts(pred.strip().lower()))

    sys.stderr.write("Total: " + str(len(goldMap)) + "\n")
    return (goldMap, predictionMap)


# m1 is the reference map
# m2 is the prediction map
def bleuFromMaps(m1, m2):
    """Average BLEU (as percentages) over the ids shared by both maps.

    m1: reference map {id: [reference sentences]}.
    m2: prediction map {id: [predicted sentence]} (first entry is scored).
    Returns [smoothed BLEU, p_1..p_4], each scaled by 100 and averaged over
    the matched ids; all zeros when the maps share no ids.
    """
    score = [0] * 5
    num = 0.0
    for key in m1:
        if key in m2:
            bl = bleu(m1[key], m2[key][0])
            score = [score[i] + bl[i] for i in range(0, len(bl))]
            num += 1
    # Fix: previously raised ZeroDivisionError when no ids overlapped.
    if num == 0:
        return [0.0] * len(score)
    return [s * 100.0 / num for s in score]


def smoothed_bleu_4(references, predictions, **kwargs):
    """Corpus-level smoothed BLEU-4 of *predictions* against parallel *references*.

    Both arguments are sequences of plain strings aligned by position; extra
    keyword arguments are accepted and ignored. Returns the combined score
    (first entry of bleuFromMaps) as a percentage.
    """
    predictionMap = {
        rid: [splitPuncts(pred.strip().lower())]
        for rid, pred in enumerate(predictions)
    }
    goldMap = {
        rid: [splitPuncts(ref.strip().lower())]
        for rid, ref in enumerate(references)
    }
    return bleuFromMaps(goldMap, predictionMap)[0]


if __name__ == "__main__":
    # Usage: bleu.py <goldfile>  (predictions are read from stdin, one per line)
    reference_file = sys.argv[1]
    predictions = list(sys.stdin)
    goldMap, predictionMap = computeMaps(predictions, reference_file)
    print(bleuFromMaps(goldMap, predictionMap)[0])