bleu.py 7.64 KB
Newer Older
lintangsutawika's avatar
lintangsutawika committed
1
#!/usr/bin/python
2
3
4
5
6
7
import re
import sys
import math
import xml.sax.saxutils

from typing import List, Pattern, Tuple, Union, Dict, Any, Optional
lintangsutawika's avatar
lintangsutawika committed
8

lintangsutawika's avatar
lintangsutawika committed
9
10
11
"""
This script was adapted from the original version by hieuhoang1972 which is part of MOSES.
"""
lintangsutawika's avatar
lintangsutawika committed
12
13
14

# $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $

lintangsutawika's avatar
lintangsutawika committed
15
"""Provides:
lintangsutawika's avatar
lintangsutawika committed
16
17
18
19
20
21
22
23

cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
score_cooked(alltest, n=4): Score a list of cooked test sentences.

score_set(s, testid, refids, n=4): Interface with dataset.py; calculate BLEU score of testid against refids.

The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible.
lintangsutawika's avatar
lintangsutawika committed
24
"""
lintangsutawika's avatar
lintangsutawika committed
25
26
27
28
29
30
31

# Added to bypass NIST-style pre-processing of hyp and ref files -- wade
nonorm = 0  # when truthy, normalize() skips all rules and only whitespace-splits

preserve_case = False  # when False, normalize() lowercases the text
eff_ref_len = "shortest"  # reference-length policy in cook_test(): "shortest", "average", or "closest"

32
# Language-independent normalization rules; each entry is
# (pattern, replacement) and is compiled below before use by normalize().
normalize1: List[Tuple[Union[Pattern[str], str], str]] = [
    ("<skipped>", ""),  # strip "skipped" tags
    (r"-\n", ""),  # strip end-of-line hyphenation and join lines
    (r"\n", " "),  # join lines
    #    (r'(\d)\s+(?=\d)', r'\1'), # join digits
]
normalize1 = [(re.compile(pat), rep) for pat, rep in normalize1]

40
# Language-dependent tokenization rules (assumes Western languages); each
# entry is (pattern, replacement) and is compiled below before use.
normalize2: List[Tuple[Union[Pattern[str], str], str]] = [
    # tokenize punctuation. apostrophe is missing
    (r"([\{-\~\[-\` -\&\(-\+\:-\@\/])", r" \1 "),
    # tokenize period and comma unless preceded by a digit
    (r"([^0-9])([\.,])", r"\1 \2 "),
    # tokenize period and comma unless followed by a digit
    (r"([\.,])([^0-9])", r" \1 \2"),
    # tokenize dash when preceded by a digit
    (r"([0-9])(-)", r"\1 \2 "),
]
normalize2 = [(re.compile(pat), rep) for pat, rep in normalize2]

lintangsutawika's avatar
lintangsutawika committed
57

lintangsutawika's avatar
lintangsutawika committed
58
def normalize(s):
    """Normalize and tokenize text. This is lifted from NIST mteval-v11a.pl.

    Accepts a string, or an iterable of strings which is joined with
    spaces first.  Returns the list of tokens.
    """
    # Added to bypass NIST-style pre-processing of hyp and ref files -- wade
    if nonorm:
        return s.split()
    # isinstance (not `type(s) is not str`) so str subclasses are treated
    # as strings rather than being space-joined character by character.
    if not isinstance(s, str):
        s = " ".join(s)
    # language-independent part:
    for pattern, replace in normalize1:
        s = pattern.sub(replace, s)  # patterns are precompiled above
    s = xml.sax.saxutils.unescape(s, {"&quot;": '"'})
    # language-dependent part (assuming Western languages):
    s = " %s " % s
    if not preserve_case:
        s = s.lower()  # this might not be identical to the original
    for pattern, replace in normalize2:
        s = pattern.sub(replace, s)
    return s.split()

lintangsutawika's avatar
lintangsutawika committed
77

lintangsutawika's avatar
lintangsutawika committed
78
def count_ngrams(words, n=4):
    """Count every n-gram of order 1..n in the token list *words*.

    Returns a dict mapping each n-gram (a tuple of tokens) to the number
    of times it occurs.
    """
    counts: Dict[Any, int] = {}
    for order in range(1, n + 1):
        for start in range(len(words) - order + 1):
            gram = tuple(words[start : start + order])
            counts[gram] = counts.get(gram, 0) + 1
    return counts

lintangsutawika's avatar
lintangsutawika committed
86

lintangsutawika's avatar
lintangsutawika committed
87
def cook_refs(refs, n=4):
    """Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them: the per-reference token lengths and,
    for each n-gram, the maximum count observed in any one reference."""
    tokenized = [normalize(ref) for ref in refs]
    maxcounts: Dict[Tuple[str], int] = {}
    for tokens in tokenized:
        for ngram, count in count_ngrams(tokens, n).items():
            if count > maxcounts.get(ngram, 0):
                maxcounts[ngram] = count
    return ([len(tokens) for tokens in tokenized], maxcounts)

lintangsutawika's avatar
lintangsutawika committed
100

lintangsutawika's avatar
lintangsutawika committed
101
def cook_test(test, item, n=4):
    """Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it:
    its token length, the effective reference length, the number of
    candidate n-grams per order, and the clipped match counts."""
    reflens, refmaxcounts = item
    tokens = normalize(test)
    testlen = len(tokens)
    result: Dict[str, Any] = {"testlen": testlen}

    # Calculate effective reference sentence length.
    if eff_ref_len == "shortest":
        result["reflen"] = min(reflens)
    elif eff_ref_len == "average":
        result["reflen"] = float(sum(reflens)) / len(reflens)
    elif eff_ref_len == "closest":
        # Reference length nearest the candidate's; first wins on ties,
        # matching the original scan order.
        min_diff: Optional[int] = None
        for reflen in reflens:
            diff = abs(reflen - testlen)
            if min_diff is None or diff < min_diff:
                min_diff = diff
                result["reflen"] = reflen

    # Maximum possible matches per n-gram order, clipped at zero.
    result["guess"] = [max(testlen - k + 1, 0) for k in range(1, n + 1)]

    # Matches per order, each n-gram clipped by its max reference count.
    correct = [0] * n
    for ngram, count in count_ngrams(tokens, n).items():
        correct[len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
    result["correct"] = correct

    return result

lintangsutawika's avatar
lintangsutawika committed
131

lintangsutawika's avatar
lintangsutawika committed
132
def score_cooked(allcomps, n=4, ground=0, smooth=1):
    """Score a list of cooked test sentences (from cook_test).

    Returns n+1 values: the smoothed corpus BLEU (with brevity penalty
    applied) followed by the n unsmoothed per-order precisions.  The
    ``ground`` argument is accepted for interface compatibility only.
    """
    # Pool the sufficient statistics across all sentences.
    totals: Dict[str, Any] = {
        "testlen": 0,
        "reflen": 0,
        "guess": [0] * n,
        "correct": [0] * n,
    }
    for comps in allcomps:
        totals["testlen"] += comps["testlen"]
        totals["reflen"] += comps["reflen"]
        for k in range(n):
            totals["guess"][k] += comps["guess"][k]
            totals["correct"][k] += comps["correct"][k]

    tiny = sys.float_info.min  # guards against log(0)
    logbleu = 0.0
    all_bleus: List[float] = []
    for k in range(n):
        correct = totals["correct"][k]
        guess = totals["guess"][k]
        # Add-one smoothing for orders above unigram when smooth == 1.
        addsmooth = 1 if (smooth == 1 and k > 0) else 0
        logbleu += math.log(correct + addsmooth + tiny) - math.log(
            guess + addsmooth + tiny
        )
        # Unsmoothed per-order log precision; large-negative sentinel
        # when there were no candidate n-grams of this order.
        if guess == 0:
            all_bleus.append(-10000000.0)
        else:
            all_bleus.append(math.log(correct + tiny) - math.log(guess))

    logbleu /= float(n)
    all_bleus.insert(0, logbleu)

    # Brevity penalty in log space, applied to the combined score only.
    brev_penalty = min(
        0, 1 - float(totals["reflen"] + 1) / (totals["testlen"] + 1)
    )
    for i in range(len(all_bleus)):
        if i == 0:
            all_bleus[i] += brev_penalty
        all_bleus[i] = math.exp(all_bleus[i])
    return all_bleus

lintangsutawika's avatar
lintangsutawika committed
173
174

def bleu(refs, candidate, ground=0, smooth=1):
    """BLEU statistics of *candidate* against the reference list *refs*.

    Convenience wrapper chaining cook_refs/cook_test/score_cooked for a
    single segment."""
    cooked_refs = cook_refs(refs)
    cooked_test = cook_test(candidate, cooked_refs)
    return score_cooked([cooked_test], ground=ground, smooth=smooth)

lintangsutawika's avatar
lintangsutawika committed
179

lintangsutawika's avatar
lintangsutawika committed
180
def splitPuncts(line):
    """Split *line* so every punctuation character becomes its own token,
    returning the space-joined result."""
    tokens = re.findall(r"[\w]+|[^\s\w]", line)
    return " ".join(tokens)

lintangsutawika's avatar
lintangsutawika committed
183
184

def computeMaps(predictions, goldfile):
    """Build the id -> tokenized-sentence maps used by bleuFromMaps.

    *predictions* is an iterable of tab-separated "id<TAB>sentence" rows
    (a row with no sentence field gets an empty prediction); *goldfile*
    is the path to a file of "id<TAB>sentence" reference rows.  Gold
    rows are kept only when their id also appears in the predictions.
    Writes the kept-reference count to stderr and returns
    (goldMap, predictionMap).
    """
    predictionMap: Dict[str, list] = {}
    goldMap: Dict[str, list] = {}

    for row in predictions:
        cols = row.strip().split("\t")
        if len(cols) == 1:
            (rid, pred) = (cols[0], "")
        else:
            (rid, pred) = (cols[0], cols[1])
        predictionMap[rid] = [splitPuncts(pred.strip().lower())]

    # Context manager so the gold file is always closed (the original
    # opened it and leaked the handle).
    with open(goldfile, "r") as gf:
        for row in gf:
            (rid, pred) = row.split("\t")
            if rid in predictionMap:  # Only insert if the id exists for the method
                if rid not in goldMap:
                    goldMap[rid] = []
                goldMap[rid].append(splitPuncts(pred.strip().lower()))

    sys.stderr.write("Total: " + str(len(goldMap)) + "\n")
    return (goldMap, predictionMap)


# m1 is the reference map
# m2 is the prediction map
lintangsutawika's avatar
lintangsutawika committed
210
def bleuFromMaps(m1, m2):
    """Average smoothed BLEU over the ids shared by both maps.

    m1 maps id -> list of reference sentences; m2 maps id -> list of
    predictions (only the first prediction per id is scored).  Returns
    the averaged score vector scaled to percentages.  When the maps
    share no ids, returns all zeros (the original raised
    ZeroDivisionError).
    """
    score = [0] * 5
    num = 0.0
    for key in m1:
        if key in m2:
            bl = bleu(m1[key], m2[key][0])
            score = [score[i] + bl[i] for i in range(0, len(bl))]
            num += 1
    if num == 0:
        # No overlapping ids: nothing was scored, avoid dividing by zero.
        return [0.0] * len(score)
    return [s * 100.0 / num for s in score]
lintangsutawika's avatar
lintangsutawika committed
220
221
222
223
224
225
226


def smoothed_bleu_4(references, predictions, **kwargs):
    """Smoothed BLEU-4 (as a percentage) of *predictions* against the
    parallel list *references*; extra keyword arguments are ignored."""
    predictionMap = {
        rid: [splitPuncts(pred.strip().lower())]
        for rid, pred in enumerate(predictions)
    }
    goldMap = {
        rid: [splitPuncts(ref.strip().lower())]
        for rid, ref in enumerate(references)
    }
    return bleuFromMaps(goldMap, predictionMap)[0]

lintangsutawika's avatar
lintangsutawika committed
234
235
236
237
238
239
240
241

if __name__ == "__main__":
    reference_file = sys.argv[1]
    predictions = []
    for row in sys.stdin:
        predictions.append(row)
    (goldMap, predictionMap) = computeMaps(predictions, reference_file)
    print(bleuFromMaps(goldMap, predictionMap)[0])