#!/usr/bin/python

"""
This script was adapted from the original version by hieuhoang1972 which is part of MOSES.
"""

# $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $
"""Provides:

cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
score_cooked(alltest, n=4): Score a list of cooked test sentences.

score_set(s, testid, refids, n=4): Interface with dataset.py; calculate BLEU score of testid against refids.

The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible.
"""
import sys, math, re, xml.sax.saxutils
import subprocess
import os

# Added to bypass NIST-style pre-processing of hyp and ref files -- wade
nonorm = 0

preserve_case = False
eff_ref_len = "shortest"

# Language-independent normalization rules; compiled once at import time.
_raw_rules1 = (
    ("<skipped>", ""),  # strip "skipped" tags
    (r"-\n", ""),  # strip end-of-line hyphenation and join lines
    (r"\n", " "),  # join lines
    #    (r'(\d)\s+(?=\d)', r'\1'), # join digits
)
normalize1 = [(re.compile(rx), sub) for rx, sub in _raw_rules1]

# Western-language tokenization rules; compiled once at import time.
_raw_rules2 = (
    # tokenize punctuation. apostrophe is missing
    (r"([\{-\~\[-\` -\&\(-\+\:-\@\/])", r" \1 "),
    # tokenize period and comma unless preceded by a digit
    (r"([^0-9])([\.,])", r"\1 \2 "),
    # tokenize period and comma unless followed by a digit
    (r"([\.,])([^0-9])", r" \1 \2"),
    # tokenize dash when preceded by a digit
    (r"([0-9])(-)", r"\1 \2 "),
)
normalize2 = [(re.compile(rx), sub) for rx, sub in _raw_rules2]

def normalize(s):
    """Normalize and tokenize text. This is lifted from NIST mteval-v11a.pl.

    Accepts either a string or an iterable of tokens (which is joined with
    spaces first) and returns a list of tokens.
    """
    # Added to bypass NIST-style pre-processing of hyp and ref files -- wade
    if nonorm:
        return s.split()
    # isinstance (not type(...) is str) so str subclasses are treated as text
    if not isinstance(s, str):
        s = " ".join(s)
    # language-independent part: patterns are pre-compiled, so call .sub
    # directly instead of going back through re.sub
    for pattern, replace in normalize1:
        s = pattern.sub(replace, s)
    s = xml.sax.saxutils.unescape(s, {"&quot;": '"'})
    # language-dependent part (assuming Western languages):
    s = " %s " % s
    if not preserve_case:
        s = s.lower()  # this might not be identical to the original
    for pattern, replace in normalize2:
        s = pattern.sub(replace, s)
    return s.split()

def count_ngrams(words, n=4):
    """Count all n-grams of order 1..n in `words`.

    Returns a dict mapping each n-gram (a tuple of tokens) to the number of
    times it occurs in the sentence.
    """
    counts = {}
    num_words = len(words)
    for order in range(1, n + 1):
        for start in range(num_words - order + 1):
            gram = tuple(words[start : start + order])
            if gram in counts:
                counts[gram] += 1
            else:
                counts[gram] = 1
    return counts

def cook_refs(refs, n=4):
    """Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them: their token lengths and, per n-gram,
    the maximum count seen in any single reference."""
    tokenized = [normalize(ref) for ref in refs]
    maxcounts = {}
    for toks in tokenized:
        for ngram, count in count_ngrams(toks, n).items():
            if count > maxcounts.get(ngram, 0):
                maxcounts[ngram] = count
    return ([len(toks) for toks in tokenized], maxcounts)

def cook_test(test, item, n=4):
    """Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.

    `item` is the (reference lengths, max n-gram counts) pair produced
    by cook_refs().
    """
    reflens, refmaxcounts = item
    words = normalize(test)
    testlen = len(words)
    result = {"testlen": testlen}

    # Calculate effective reference sentence length.
    if eff_ref_len == "shortest":
        result["reflen"] = min(reflens)
    elif eff_ref_len == "average":
        result["reflen"] = float(sum(reflens)) / len(reflens)
    elif eff_ref_len == "closest":
        # First reference with a strictly smaller length difference wins,
        # matching the original tie-breaking behavior.
        best_diff = None
        for reflen in reflens:
            diff = abs(reflen - testlen)
            if best_diff is None or diff < best_diff:
                best_diff = diff
                result["reflen"] = reflen

    # Number of candidate n-grams of each order 1..n.
    result["guess"] = [max(testlen - k + 1, 0) for k in range(1, n + 1)]

    # Matching n-gram counts per order, clipped by the reference maxima.
    correct = [0] * n
    for ngram, count in count_ngrams(words, n).items():
        correct[len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
    result["correct"] = correct

    return result

def score_cooked(allcomps, n=4, ground=0, smooth=1):
    """Combine cooked test-sentence statistics into BLEU scores.

    Returns a list whose first entry is the (smoothed) BLEU score with the
    brevity penalty applied, followed by the n individual n-gram precisions.
    `ground` is accepted for interface compatibility but is not used.
    """
    # Accumulate corpus-level totals over all sentences.
    testlen = 0
    reflen = 0
    guess = [0] * n
    correct = [0] * n
    for comps in allcomps:
        testlen += comps["testlen"]
        reflen += comps["reflen"]
        for k in range(n):
            guess[k] += comps["guess"][k]
            correct[k] += comps["correct"][k]

    tiny = sys.float_info.min  # keeps log() away from zero
    logbleu = 0.0
    all_bleus = []
    for k in range(n):
        # Add-one smoothing for orders above unigram when smooth == 1.
        addsmooth = 1 if (smooth == 1 and k > 0) else 0
        logbleu += math.log(correct[k] + addsmooth + tiny) - math.log(
            guess[k] + addsmooth + tiny
        )
        if guess[k] == 0:
            all_bleus.append(-10000000)
        else:
            all_bleus.append(math.log(correct[k] + tiny) - math.log(guess[k]))

    logbleu /= float(n)
    all_bleus.insert(0, logbleu)

    # Brevity penalty in log space (always <= 0).
    brevPenalty = min(0, 1 - float(reflen + 1) / (testlen + 1))
    for i in range(len(all_bleus)):
        if i == 0:
            all_bleus[i] += brevPenalty
        all_bleus[i] = math.exp(all_bleus[i])
    return all_bleus


def bleu(refs, candidate, ground=0, smooth=1):
    """BLEU of a single candidate string against its reference strings."""
    cooked_refs = cook_refs(refs)
    cooked_candidate = cook_test(candidate, cooked_refs)
    return score_cooked([cooked_candidate], ground=ground, smooth=smooth)

def splitPuncts(line):
    """Re-tokenize `line` so every non-word character is its own token."""
    tokens = re.findall(r"[\w]+|[^\s\w]", line)
    return " ".join(tokens)


def computeMaps(predictions, goldfile):
    """Build {id: [sentence]} maps for predictions and gold references.

    `predictions` is an iterable of "id<TAB>text" rows; `goldfile` is the
    path of a file in the same format.  Gold rows are kept only when their
    id also appears among the predictions.  Returns (goldMap, predictionMap).
    """
    predictionMap = {}
    goldMap = {}

    for row in predictions:
        cols = row.strip().split("\t")
        # A row with no tab is an id with an empty prediction.
        if len(cols) == 1:
            rid, pred = cols[0], ""
        else:
            rid, pred = cols[0], cols[1]
        predictionMap[rid] = [splitPuncts(pred.strip().lower())]

    # Context manager so the gold file is always closed (the original
    # leaked the handle).  maxsplit=1 so reference text containing tabs
    # no longer crashes the unpack.
    with open(goldfile, "r") as gf:
        for row in gf:
            rid, pred = row.split("\t", 1)
            if rid in predictionMap:  # Only insert if the id exists for the method
                goldMap.setdefault(rid, []).append(splitPuncts(pred.strip().lower()))

    sys.stderr.write("Total: " + str(len(goldMap)) + "\n")
    return (goldMap, predictionMap)


# m1 is the reference map
# m2 is the prediction map
def bleuFromMaps(m1, m2):
    """Average sentence-level BLEU of predictions `m2` against references `m1`.

    Only ids present in both maps are scored; only the first prediction per
    id is used.  Returns per-component averages scaled to percentages.
    """
    totals = [0] * 5
    matched = 0.0
    for key, refs in m1.items():
        if key not in m2:
            continue
        components = bleu(refs, m2[key][0])
        totals = [totals[i] + components[i] for i in range(len(components))]
        matched += 1
    return [t * 100.0 / matched for t in totals]


def smoothed_bleu_4(references, predictions, **kwargs):
    """Corpus-level smoothed BLEU-4 (as a percentage).

    `references` and `predictions` are parallel sequences of strings;
    extra keyword arguments are accepted and ignored.
    """
    goldMap = {
        idx: [splitPuncts(ref.strip().lower())]
        for idx, ref in enumerate(references)
    }
    predictionMap = {
        idx: [splitPuncts(pred.strip().lower())]
        for idx, pred in enumerate(predictions)
    }
    return bleuFromMaps(goldMap, predictionMap)[0]


if __name__ == "__main__":
    # Usage: bleu.py <goldfile>  (predictions are read from stdin)
    reference_file = sys.argv[1]
    predictions = list(sys.stdin)
    goldMap, predictionMap = computeMaps(predictions, reference_file)
    print(bleuFromMaps(goldMap, predictionMap)[0])