utils.py 1.5 KB
Newer Older
1
2
3
4
import re
import string
import collections

lintangsutawika's avatar
format  
lintangsutawika committed
5

6
7
def normalize_answer(s):
    """Return a canonical form of ``s`` for answer comparison.

    The steps run in this order:

    1. lowercase the text;
    2. drop every character found in ``string.punctuation``;
    3. replace the standalone words "a", "an" and "the" with a space;
    4. collapse runs of whitespace to single spaces and trim both ends.

    Punctuation is removed before articles, so a token such as "a.n"
    first becomes "an" and is then stripped as an article.
    """
    lowered = s.lower()

    punctuation = set(string.punctuation)
    without_punct = "".join(ch for ch in lowered if ch not in punctuation)

    # \b keeps the match to whole words, so "theater" is left untouched.
    without_articles = re.sub(
        r"\b(a|an|the)\b", " ", without_punct, flags=re.UNICODE
    )

    # split() with no argument discards leading/trailing/repeated whitespace.
    return " ".join(without_articles.split())

lintangsutawika's avatar
format  
lintangsutawika committed
25

26
def get_tokens(s):
    """Split the normalized form of ``s`` into whitespace-separated tokens.

    A falsy ``s`` (for example ``""`` or ``None``) yields an empty list
    without being normalized.
    """
    return normalize_answer(s).split() if s else []
lintangsutawika's avatar
lintangsutawika committed
30

lintangsutawika's avatar
format  
lintangsutawika committed
31

lintangsutawika's avatar
lintangsutawika committed
32
# Exact match (the normalized prediction exactly matches the normalized gold answer)
33
34
def exact(predictions, references):
    """Return 1 if the first prediction matches the first reference, else 0.

    Both strings are passed through ``normalize_answer`` before being
    compared. Only element 0 of each sequence is read.
    """
    gold = normalize_answer(references[0])
    predicted = normalize_answer(predictions[0])
    return 1 if gold == predicted else 0
lintangsutawika's avatar
lintangsutawika committed
35

lintangsutawika's avatar
format  
lintangsutawika committed
36

lintangsutawika's avatar
lintangsutawika committed
37
38
# The F-score of predicted tokens versus the gold answer
def f1(predictions, references):
    """Return the token-level F1 of the first prediction vs. the first reference.

    Both strings are tokenized with ``get_tokens``. Only element 0 of each
    sequence is read.

    Returns:
        * ``1`` or ``0`` (int) when either side has no tokens: 1 if both are
          empty, 0 otherwise;
        * ``0`` (int) when the two sides share no tokens;
        * otherwise the harmonic mean of precision and recall (float), where
          the overlap counts each shared token as many times as it appears
          on both sides.
    """
    gold_tokens = get_tokens(references[0])
    pred_tokens = get_tokens(predictions[0])

    # If either side is a no-answer, F1 is 1 only if both are.
    if not gold_tokens or not pred_tokens:
        return int(gold_tokens == pred_tokens)

    # Counter "&" keeps the minimum count of each token present in both.
    shared = collections.Counter(gold_tokens) & collections.Counter(pred_tokens)
    overlap = sum(shared.values())
    if not overlap:
        return 0

    precision = float(overlap) / len(pred_tokens)
    recall = float(overlap) / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)