script.py 3.94 KB
Newer Older
am-bean's avatar
am-bean committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import ast
import re
import unicodedata as ud


def clean_answer(answer: str):
    """Normalise a free-text answer for comparison.

    Strips surrounding whitespace and trailing full stops, collapses runs
    of spaces, lower-cases, removes '+' characters (not currently
    markable), maps curly quotes to straight ones, and applies NFKD
    unicode normalisation so visually-identical strings compare equal.
    """
    normalised = answer.strip().strip(".")

    # Collapse repeated spaces, then case-fold to lower case.
    normalised = re.sub(r" +", " ", normalised).lower()

    # Internal '+' cannot currently be handled when marking — drop it.
    normalised = normalised.replace("+", "")

    # Replace typographic ("curly") quotes with their ASCII equivalents.
    for fancy, plain in (("‘", "'"), ("’", "'"), ("“", '"'), ("”", '"')):
        normalised = normalised.replace(fancy, plain)

    # Canonicalise remaining unicode (compatibility decomposition).
    return ud.normalize("NFKD", normalised)


def safe_exact(references: list[str], predictions: list[str]):
    """Exact-match score between the first reference and first prediction.

    An empty reference is trivially satisfied (1.0); an empty prediction
    against a non-empty reference scores 0.0; otherwise 1.0 on equality.
    """
    reference = references[0]
    prediction = predictions[0]

    if not reference:
        return 1.0
    if not prediction:
        return 0.0

    return float(reference == prediction)


def parse_str_list_score(model, correct, scoring_func):
    """Score a model answer against a reference that may be a plain
    string, a stringified list (e.g. ``"['a', 'b']"``), or a (possibly
    nested) list of alternative references.

    Args:
        model: model output; coerced to ``str`` before scoring.
        correct: reference answer(s) in any of the shapes above.
        scoring_func: callable taking ``references=[...]`` and
            ``predictions=[...]`` and returning a numeric score.

    Returns:
        float: the best score over all reference alternatives
        (1.0 for an empty reference, 0.0 for an empty model answer).
    """
    model = str(model)
    # Empty reference: nothing to match, count as trivially correct.
    if len(correct) == 0:
        return 1.0
    # Empty model answer against a non-empty reference: incorrect.
    if len(model) == 0:
        return 0.0

    # A reference like "['a', 'b']" is a stringified list of
    # alternatives; parse it back into a real list when it looks safe.
    if ("[" in correct) and (("'" in correct) or ('"' in correct)):
        parsed = ast.literal_eval(correct)
        if isinstance(parsed, list):
            correct = parsed

    if isinstance(correct, list):
        if all(isinstance(c, str) for c in correct):
            # bleu and rouge are expensive and don't make sense for
            # any-order problems, so fall back to cleaned exact matching.
            # Fix: wrap in float() so this path matches the float return
            # type of every other path (it previously returned a bool).
            if len(correct) > 24:
                return float(
                    clean_answer(model) in [clean_answer(c) for c in correct]
                )
            max_score = 0.0
            for c in correct:
                score = scoring_func(
                    references=[clean_answer(c)],
                    predictions=[clean_answer(model)],
                )
                if score > max_score:
                    max_score = score
            return max_score

        # Mixed/nested references: flatten nested lists into a single
        # comma-separated string before scoring. (The original code
        # duplicated the identical scoring call in both branches of the
        # isinstance check; merged here.)
        max_score = 0.0
        for c in correct:
            if isinstance(c, list):
                c = ", ".join(c)
            score = scoring_func(
                references=[clean_answer(c)],
                predictions=[clean_answer(model)],
            )
            if score > max_score:
                max_score = score
        return max_score

    # Single plain-string reference.
    return scoring_func(
        references=[clean_answer(correct)],
        predictions=[clean_answer(model)],
    )


91
92
def exact_match(references: list[str], predictions: list[str]):
    """Per-key exact-match scores between a reference dict and a
    (possibly malformed) dict embedded in the model prediction.

    ``references[0]`` must be a Python-literal dict string. The
    prediction is parsed with ``ast.literal_eval`` when possible,
    otherwise a regex fallback extracts values key-by-key. Returns a
    list of float scores, one per reference key.
    """
    ref_dict = ast.literal_eval(references[0])
    try:
        # Require an opening brace; slice from it so leading chatter
        # before the dict is ignored, appending '}' if it was cut off.
        assert "{" in predictions[0]
        if predictions[0][-1] == "}":
            pred_dict = ast.literal_eval(predictions[0][predictions[0].index("{") :])
        else:
            pred_dict = ast.literal_eval(
                predictions[0][predictions[0].index("{") :] + "}"
            )
    except (SyntaxError, ValueError, AssertionError):
        # Fallback: the prediction is not a parseable dict literal, so
        # scrape each reference key's value out with regexes (single-
        # and double-quoted variants).
        pred_dict = {}
        for k in ref_dict.keys():
            m = re.search(re.escape(str(k)) + """': ([^']+)'[,\\}]""", predictions[0])
            n = re.search(re.escape(str(k)) + """": ([^"]+)"[,\\}]""", predictions[0])
            # NOTE(review): group() returns the ENTIRE match, including
            # the key text and quote characters — group(1) looks like the
            # intended capture here; confirm against expected scores.
            if m:
                pred_dict[k] = m.group()[:-1]
            elif n:
                pred_dict[k] = n.group()[:-1]
            else:
                # Key not found in the prediction at all.
                pred_dict[k] = ""
    # Ensure every reference key is present, defaulting missing ones to
    # the empty string (scored 0.0 downstream unless the reference is
    # also empty).
    pred_dict_full = {
        k: pred_dict[k] if k in pred_dict else "" for k in ref_dict.keys()
    }

    # One exact-match score per reference key.
    scores = [
        parse_str_list_score(pred_dict_full[k], v, safe_exact)
        for k, v in ref_dict.items()
    ]

    return scores


def aggregate_scores(input):
    """Mean of all per-item scores across a list of score lists."""
    flattened = [score for group in input for score in group]
    return sum(flattened) / len(flattened)


def aggregate_metrics(
    metrics_scores: list[int], dataset_size: list[int], weight_by_size: bool
):
    """Difference between the first two metric scores.

    NOTE(review): ``dataset_size`` and ``weight_by_size`` are currently
    unused — presumably kept for interface compatibility with other
    aggregators; confirm with callers before removing.
    """
    first, second = metrics_scores[0], metrics_scores[1]
    return first - second