import re
from typing import Optional

from Levenshtein import distance


# Few-shot prompt required for the external LM answer-extraction call

DEMO_PROMPT = """
Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.

Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.
Question: Which number is missing?

Model response: The number missing in the sequence is 14.

Extracted answer: 14

Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.
Question: What is the fraction of females facing the camera?

Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.

Extracted answer: 0.6

Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end.
Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $)

Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.

Extracted answer: 1.45

Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.
Question: Between which two years did the line graph see its maximum peak?

Model response: The line graph saw its maximum peak between 2007 and 2008.

Extracted answer: [2007, 2008]

Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
Question: What fraction of the shape is blue?\nChoices:\n(A) 3/11\n(B) 8/11\n(C) 6/11\n(D) 3/5

Model response: The correct answer is (B) 8/11.

Extracted answer: B
"""


def create_test_prompt(demo_prompt, query, response):
    """
    Build the full extraction prompt: few-shot demos, then the query/response pair, ending with "Extracted answer: "
    """
    demo_prompt = demo_prompt.strip()
    test_prompt = f"{query}\n\n{response}"
    full_prompt = f"{demo_prompt}\n\n{test_prompt}\n\nExtracted answer: "
    return full_prompt
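
# Illustrative usage (comment-only sketch; the query and response strings are
# hypothetical). The assembled prompt ends with "Extracted answer: " so the
# external LM completes it with just the final answer:
#
#   prompt = create_test_prompt(
#       DEMO_PROMPT,
#       "Question: Which number is missing?",
#       "The number missing in the sequence is 14.",
#   )
#   # -> "<demos>\n\nQuestion: Which number is missing?\n\n...\n\nExtracted answer: "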


# taken from https://github.com/lupantech/MathVista/blob/main/evaluation/calculate_score.py
def get_most_similar(prediction: str, choices: list) -> str:
    """
    Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction
    """
    distances = [distance(prediction, choice) for choice in choices]
    ind = distances.index(min(distances))
    return choices[ind]
    # return min(choices, key=lambda choice: distance(prediction, choice))
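
# Illustrative usage (hypothetical values); ties are broken by the first
# minimum, matching `distances.index(min(distances))` above:
#
#   get_most_similar("8/11", ["3/11", "8/11", "6/11", "3/5"])     # -> "8/11"
#   get_most_similar("(B 8/11", ["3/11", "8/11", "6/11", "3/5"])  # -> "8/11"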


# taken from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
def normalize_extracted_answer(
    extraction: str,
    choices: list,
    question_type: str,
    answer_type: str,
    precision,
    ignore_empty_extractions=True,
) -> Optional[str]:
    """
    Normalize the extracted answer to match the answer type
    """

    if question_type == "multi_choice":
        # make sure the extraction is a string
        if isinstance(extraction, str):
            extraction = extraction.strip()
        else:
            try:
                extraction = str(extraction)
            except Exception:
                extraction = ""

        # if the extraction is empty, return None
        if ignore_empty_extractions and not extraction:
            return None

        # extract "A" from "(A) text"
        letter = re.findall(r"\(([a-zA-Z])\)", extraction)
        if len(letter) > 0:
            extraction = letter[0].upper()

        sequential_characters = [chr(ord("A") + i) for i in range(len(choices))]

        # if model output a character, use it as index of available choices
        if extraction in sequential_characters:
            option_index = sequential_characters.index(extraction)
            normalized_extraction = choices[option_index]
        else:
            # select the most similar option
            normalized_extraction = get_most_similar(extraction, choices)
        assert normalized_extraction in choices

    elif answer_type == "integer":
        try:
            normalized_extraction = str(int(float(extraction)))
        except Exception:
            normalized_extraction = None

    elif answer_type == "float":
        try:
            normalized_extraction = str(round(float(extraction), int(precision)))
        except Exception:
            normalized_extraction = None

    elif answer_type == "list":
        try:
            normalized_extraction = str(extraction)
        except Exception:
            normalized_extraction = None

    return normalized_extraction
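
# Illustrative behaviour (hypothetical inputs):
#
#   normalize_extracted_answer("B", ["3/11", "8/11"], "multi_choice", "text", None)
#   # -> "8/11"  (option letter mapped onto the choices list)
#   normalize_extracted_answer("14.0", [], "free_form", "integer", None)
#   # -> "14"
#   normalize_extracted_answer("0.60", [], "free_form", "float", 1)
#   # -> "0.6"   (rounded to `precision` decimal places)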


def safe_equal(prediction, answer):
    """
    Check if the prediction is equal to the answer, even if they are of different types
    """
    try:
        if prediction == answer:
            return True
        return False
    except Exception:
        return False
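
# Illustrative behaviour: plain mismatches return False without raising; the
# try/except only guards against exotic __eq__ implementations:
#
#   safe_equal("14", "14")  # -> True
#   safe_equal("14", 14)    # -> False (no type coercion)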


def extract_answer(response: str, problem: dict) -> str:
    question_type = problem["question_type"]
    answer_type = problem["answer_type"]
    choices = problem["choices"]
    # query = problem["query"]
    # pid = problem['pid']

    if response == "":
        return ""

    ### Not in the original MathVista code: first try an explicit "the answer is ..." pattern
    extract = re.findall(
        r"[tT]he answer is ([A-Za-z0-9]+(?:\.[A-Za-z0-9]+)?)", response
    )
    if extract:
        return str(extract[0])
    ###

    if question_type == "multi_choice" and response in choices:
        return response

    if answer_type == "integer":
        try:
            extraction = int(response)
            return str(extraction)
        except Exception:
            pass

    if answer_type == "float":
        try:
            extraction = str(float(response))
            return extraction
        except Exception:
            pass

    return response
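
# Illustrative usage (hypothetical problem dict, shaped like a MathVista doc):
#
#   problem = {"question_type": "free_form", "answer_type": "integer", "choices": []}
#   extract_answer("The answer is 14.", problem)  # -> "14" (regex fallback above)
#   extract_answer("14", problem)                 # -> "14" (integer branch)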


# adapted from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
def process_results(doc: dict, results: list[str]):
    response = results[0]
    choices = doc["choices"]
    question_type = doc["question_type"]
    answer_type = doc["answer_type"]
    precision = doc["precision"]
    answer = doc["answer"]
    extracted_answer = extract_answer(response, doc)
    normalized_extraction = normalize_extracted_answer(
        extracted_answer, choices, question_type, answer_type, precision
    )
    res = safe_equal(normalized_extraction, answer)
    return {"acc": 1.0} if res else {"acc": 0.0}


### MathVista MCQ ###


def process_docs_mcq(dataset):
    return dataset.filter(lambda x: x["question_type"] == "multi_choice")
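
# Illustrative usage (assumes a HuggingFace `datasets.Dataset` with a
# "question_type" column; the dataset id below is an assumption):
#
#   from datasets import load_dataset
#   ds = load_dataset("AI4Math/MathVista", split="testmini")
#   mcq_only = process_docs_mcq(ds)  # keeps only multiple-choice questions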