utils.py 4.38 KB
Newer Older
Baber's avatar
Baber committed
1
import re
Baber's avatar
Baber committed
2
from typing import Optional
Baber's avatar
Baber committed
3
4
5
6
7

from Levenshtein import distance


# taken from https://github.com/lupantech/MathVista/blob/main/evaluation/calculate_score.py
Baber's avatar
Baber committed
8
def get_most_similar(prediction: str, choices: list) -> str:
    """
    Use the Levenshtein distance (or edit distance) to determine which of the
    choices is most similar to the given prediction.

    Args:
        prediction: Text to compare against every choice.
        choices: Candidate options; must contain at least one element.

    Returns:
        The element of ``choices`` with the smallest edit distance to
        ``prediction``. When several choices tie, the earliest one wins.

    Raises:
        ValueError: If ``choices`` is empty.
    """
    if not choices:
        raise ValueError("choices must not be empty")
    # min() keeps the first minimal element, so ties resolve to the earliest choice
    return min(choices, key=lambda choice: distance(prediction, choice))


Baber's avatar
Baber committed
18
# taken from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
Baber's avatar
Baber committed
19
def normalize_extracted_answer(
    extraction: str,
    choices: list,
    question_type: str,
    answer_type: str,
    precision,
    ignore_empty_extractions=True,
) -> Optional[str]:
    """
    Normalize the extracted answer to match the answer type.

    Args:
        extraction: Raw answer text pulled from the model response.
        choices: Candidate options; only consulted when ``question_type`` is
            "multi_choice".
        question_type: "multi_choice" selects option matching; any other value
            falls through to the ``answer_type`` handling.
        answer_type: "integer", "float" or "list" for non-multi-choice
            questions.
        precision: Number of decimal places kept for "float" answers; it is
            converted with ``int`` before rounding.
        ignore_empty_extractions: When true, an empty multi-choice extraction
            yields ``None`` instead of being fuzzy-matched against ``choices``.

    Returns:
        The normalized answer, or ``None`` when the extraction is empty, cannot
        be parsed as the requested numeric type, or the answer type is not one
        of the handled values.
    """
    # Default for answer types no branch handles; without it the final return
    # would hit an unbound local.
    normalized_extraction = None

    if question_type == "multi_choice":
        # make sure the extraction is a string
        if isinstance(extraction, str):
            extraction = extraction.strip()
        else:
            try:
                extraction = str(extraction)
            except Exception:
                extraction = ""

        # if the extraction is empty, return None
        if ignore_empty_extractions and not extraction:
            return None

        # extract "A" from "(A) text"
        letter = re.findall(r"\(([a-zA-Z])\)", extraction)
        if letter:
            extraction = letter[0].upper()

        # option labels "A", "B", ... one per available choice
        sequential_characters = [chr(ord("A") + i) for i in range(len(choices))]

        # if model output a character, use it as index of available choices
        if extraction in sequential_characters:
            option_index = sequential_characters.index(extraction)
            normalized_extraction = choices[option_index]
        else:
            # select the most similar option
            normalized_extraction = get_most_similar(extraction, choices)
        assert normalized_extraction in choices

    elif answer_type == "integer":
        try:
            # go through float so that "3.0" is accepted and truncated to "3"
            normalized_extraction = str(int(float(extraction)))
        except (TypeError, ValueError, OverflowError):
            normalized_extraction = None

    elif answer_type == "float":
        try:
            normalized_extraction = str(round(float(extraction), int(precision)))
        except (TypeError, ValueError, OverflowError):
            normalized_extraction = None

    elif answer_type == "list":
        normalized_extraction = str(extraction)

    return normalized_extraction


def safe_equal(prediction, answer):
    """
    Compare the prediction with the answer using ``==``, tolerating operands
    of different types.

    Returns:
        ``True`` when the comparison is truthy, ``False`` when it is falsy or
        when evaluating it raises an exception.
    """
    try:
        # bool() forces the truth test inside the guard, so a comparison
        # result that cannot be reduced to a boolean also yields False
        return bool(prediction == answer)
    except Exception:
        return False


Baber's avatar
Baber committed
94
95
96
97
98
99
def extract_answer(response: str, problem: dict) -> str:
    """
    Pull the candidate answer out of a raw model response.

    The checks run in this order and the first one that succeeds wins:

    1. An empty (or missing) response yields "".
    2. A phrase of the form "The answer is X" / "the answer is X" yields X,
       where X is an alphanumeric token with an optional ".suffix", or a
       negative number such as "-3" or "-0.25".
    3. For "multi_choice" questions, a response that is exactly one of the
       choices is returned unchanged.
    4. For "integer" / "float" answer types, a response that parses as that
       type is returned in its canonical string form.
    5. Otherwise the response is returned unchanged.

    Args:
        response: Raw text produced by the model.
        problem: Mapping providing the "question_type", "answer_type" and
            "choices" keys.

    Returns:
        The extracted answer as a string.

    Raises:
        KeyError: If ``problem`` lacks one of the required keys.
    """
    question_type = problem["question_type"]
    answer_type = problem["answer_type"]
    choices = problem["choices"]

    # Treat a missing response like an empty one so the regex never sees None
    if response is None or response == "":
        return ""

    ### This is not in the original code:
    # The second alternative is only tried where the first cannot match, i.e.
    # when the answer starts with a minus sign; it lets negative numbers be
    # captured instead of falling through with the whole sentence.
    extract = re.findall(
        r"[tT]he answer is ([A-Za-z0-9]+(?:\.[A-Za-z0-9]+)?|-\d+(?:\.\d+)?)",
        response,
    )
    if extract:
        return str(extract[0])
    ###

    if question_type == "multi_choice" and response in choices:
        return response

    if answer_type == "integer":
        try:
            return str(int(response))
        except (TypeError, ValueError):
            pass  # not a bare integer; fall through to the raw response

    if answer_type == "float":
        try:
            return str(float(response))
        except (TypeError, ValueError):
            pass  # not a bare float; fall through to the raw response

    return response
Baber's avatar
Baber committed
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145


# adapted from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
def process_results(doc: dict, results: list[str]):
    """
    Score one document against the first model response.

    The response is run through ``extract_answer`` and
    ``normalize_extracted_answer`` and the outcome is compared with
    ``doc["answer"]`` via ``safe_equal``.

    Args:
        doc: Mapping providing the "choices", "question_type", "answer_type",
            "precision" and "answer" keys.
        results: Model responses; only the first element is used.

    Returns:
        ``{"acc": 1.0}`` when the normalized answer equals the reference
        answer, otherwise ``{"acc": 0.0}``.
    """
    model_output = results[0]
    # Read every field up front, in a fixed order, before any processing
    choices, question_type, answer_type, precision, gold = (
        doc[key]
        for key in ("choices", "question_type", "answer_type", "precision", "answer")
    )
    candidate = extract_answer(model_output, doc)
    prediction = normalize_extracted_answer(
        candidate, choices, question_type, answer_type, precision
    )
    is_correct = safe_equal(prediction, gold)
    return {"acc": 1.0 if is_correct else 0.0}