"megatron/training/global_vars.py" did not exist on "3aca141586a4b8cdc983c3ecf5f7baf60506c7f8"
utils.py 1.45 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import re
import string
from collections import Counter


def normalize_answer(s):
    """
    Canonicalize an answer string before comparison.

    Mirrors the normalization in the official evaluation script for v1.1 of
    the SQuAD dataset, applied in this order:

    1. lower-case the text;
    2. delete every character in ``string.punctuation``;
    3. replace the whole words "a", "an" and "the" with a space;
    4. collapse whitespace runs to single spaces and trim both ends.

    >>> normalize_answer("  An apple, a day.  ")
    'apple day'
    """
    # Deletion table: maps every punctuation code point to None.
    punctuation_table = str.maketrans("", "", string.punctuation)

    text = s.lower()
    text = text.translate(punctuation_table)
    # Articles are replaced with a space rather than deleted so neighbouring
    # words can never fuse; the final whitespace pass tidies up the gaps.
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def f1(items):
    """
    Return the mean token-level F1 score over (gold, prediction) pairs.

    Taken from the official evaluation script for v1.1 of the SQuAD dataset.
    Both strings of each pair are passed through ``normalize_answer`` and
    split on whitespace. Precision and recall are computed from the multiset
    overlap of the two token lists.

    Args:
        items: iterable of sequences whose first element is the gold
            (reference) answer string and whose second element is the
            predicted answer string. Any further elements are ignored.

    Returns:
        The arithmetic mean of the per-pair F1 scores, as a float. A pair
        with no overlapping tokens scores 0, including when either side
        normalizes to nothing. An empty ``items`` yields 0.0; previously it
        raised an ``IndexError``.
    """
    f1_list = []

    for gold, pred, *_ in items:
        prediction_tokens = normalize_answer(pred).split()
        reference_tokens = normalize_answer(gold).split()
        # Counter '&' keeps the minimum count of each token: multiset overlap.
        common = Counter(prediction_tokens) & Counter(reference_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            # Also covers empty token lists, so the divisions below are safe.
            f1_list.append(0.0)
            continue
        precision = num_same / len(prediction_tokens)
        recall = num_same / len(reference_tokens)
        f1_list.append((2 * precision * recall) / (precision + recall))

    if not f1_list:
        # No pairs to score: avoid the ZeroDivisionError of an empty mean.
        return 0.0
    return sum(f1_list) / len(f1_list)