import re
from typing import Optional

# from api_model import make_concurrent_requests
from Levenshtein import distance


# required for external LM call

DEMO_PROMPT = """
Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.

Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.
Question: Which number is missing?

Model response: The number missing in the sequence is 14.

Extracted answer: 14

Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.
Question: What is the fraction of females facing the camera?

Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.

Extracted answer: 0.6

Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end.
Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $)

Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.

Extracted answer: 1.45

Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.
Question: Between which two years does the line  graph saw its maximum peak?

Model response: The line graph saw its maximum peak between 2007 and 2008.

Extracted answer: [2007, 2008]

Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
Question: What fraction of the shape is blue?\nChoices:\n(A) 3/11\n(B) 8/11\n(C) 6/11\n(D) 3/5

Model response: The correct answer is (B) 8/11.

Extracted answer: B
"""


def create_test_prompt(demo_prompt, query, response):
    demo_prompt = demo_prompt.strip()
    test_prompt = f"{query}\n\n{response}"
    full_prompt = f"{demo_prompt}\n\n{test_prompt}\n\nExtracted answer: "
    return full_prompt
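
# Illustrative sketch (hypothetical query/response, not part of the original module):
#   create_test_prompt(DEMO_PROMPT, "Question: What is 2 + 2?", "The sum is 4.")
# returns DEMO_PROMPT followed by the query, the model response, and a trailing
# "Extracted answer: " cue for an external extraction model to complete.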


def verify_extraction(extraction):
    extraction = extraction.strip()
    if not extraction:
        return False
    return True


# taken from https://github.com/lupantech/MathVista/blob/main/evaluation/calculate_score.py
def get_most_similar(prediction: str, choices: list) -> str:
    """
    Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction
    """
    distances = [distance(prediction, choice) for choice in choices]
    ind = distances.index(min(distances))
    return choices[ind]
    # return min(choices, key=lambda choice: distance(prediction, choice))
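
# Illustrative usage (hypothetical values, not from the MathVista source):
#   get_most_similar("8/11", ["3/11", "8/11", "6/11", "3/5"])  # -> "8/11"
# "8/11" has edit distance 0 to the second choice, so that choice is returned.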


# taken from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
def normalize_extracted_answer(
    extraction: str,
    choices: list,
    question_type: str,
    answer_type: str,
    precision,
    ignore_empty_extractions=True,
) -> Optional[str]:
    """
    Normalize the extracted answer to match the answer type
    """

    if question_type == "multi_choice":
        # make sure the extraction is a string
        if isinstance(extraction, str):
            extraction = extraction.strip()
        else:
            try:
                extraction = str(extraction)
            except Exception:
                extraction = ""

        # if the extraction is empty, return None
        if ignore_empty_extractions and not extraction:
            return None

        # extract "A" from "(A) text"
        letter = re.findall(r"\(([a-zA-Z])\)", extraction)
        if len(letter) > 0:
            extraction = letter[0].upper()

        sequential_characters = [chr(ord("A") + i) for i in range(len(choices))]

        # if model output a character, use it as index of available choices
        if extraction in sequential_characters:
            option_index = sequential_characters.index(extraction)
            normalized_extraction = choices[option_index]
        else:
            # select the most similar option
            normalized_extraction = get_most_similar(extraction, choices)
        assert normalized_extraction in choices

    elif answer_type == "integer":
        try:
            normalized_extraction = str(int(float(extraction)))
        except Exception:
            normalized_extraction = None

    elif answer_type == "float":
        try:
            normalized_extraction = str(round(float(extraction), int(precision)))
        except Exception:
            normalized_extraction = None

    elif answer_type == "list":
        try:
            normalized_extraction = str(extraction)
        except Exception:
            normalized_extraction = None

    return normalized_extraction
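
# Illustrative sketches of the normalization (hypothetical inputs, not from the source):
#   normalize_extracted_answer("(B) 8/11", ["3/11", "8/11", "6/11", "3/5"],
#                              "multi_choice", "text", None)             # -> "8/11"
#   normalize_extracted_answer("3.0", [], "free_form", "integer", None)  # -> "3"
#   normalize_extracted_answer("0.666", [], "free_form", "float", 2)     # -> "0.67"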


def safe_equal(prediction, answer):
    """
    Check if the prediction is equal to the answer, even if they are of different types
    """
    try:
        if prediction == answer:
            return True
        return False
    except Exception:
        return False


def extract_answer(response: str, problem: dict, quick_extract=True) -> str:
    """Extract the final answer string from a model response for a MathVista problem."""
    question_type = problem["question_type"]
    answer_type = problem["answer_type"]
    choices = problem["choices"]
    # query = problem["query"]
    # pid = problem["pid"]

    if response == "":
        return ""

    if question_type == "multi_choice" and response in choices:
        return response

    if answer_type == "integer":
        try:
            extraction = int(response)
            return str(extraction)
        except Exception:
            pass

    if answer_type == "float":
        try:
            extraction = str(float(response))
            return extraction
        except Exception:
            pass

    # quick extraction
    if quick_extract:
        # The answer is "text". -> "text"
        try:
            result = re.search(r'The answer is "(.*)"\.', response)
            if result:
                extraction = result.group(1)
                return extraction
        except Exception:
            pass

    # general extraction
    # try:
    #     full_prompt = create_test_prompt(DEMO_PROMPT, query, response)
    #     extraction = make_concurrent_requests(full_prompt)
    #     return extraction
    # except Exception:
    #     print(
    #         f"Error in extracting answer for problem: {pid} with response: {response}"
    #     )
    #     # logging.info(f"Error in extracting answer for problem: {pid} with response: {response}")
    #     # logging.info(e)

    return ""


def extract_all_answers(
    resps: list[list[str]], docs: list[dict], quick_extract=True
) -> list[str]:
    return [
        extract_answer(resp[0], doc, quick_extract=quick_extract)
        for resp, doc in zip(resps, docs)
    ]
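
# Illustrative sketch (assumed filter-style wiring over per-doc generations; the exact
# harness call site is not shown in this module):
#   resps = [['The answer is "14".']]
#   docs = [{"question_type": "free_form", "answer_type": "integer", "choices": []}]
#   extract_all_answers(resps, docs)  # -> ["14"]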


# adapted from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
def process_results(doc: dict, results: list[str]):
    """Score a single doc: normalize the extracted answer and compare it to the gold answer."""
    response = results[0]
    choices = doc["choices"]
    question_type = doc["question_type"]
    answer_type = doc["answer_type"]
    precision = doc["precision"]
    answer = doc["answer"]
    # step 1: extract the answer from the model response
    # extracted_answer = extract_answer(response, doc)
    extracted_answer = response[0]
    if verify_extraction(extracted_answer):
        normalized_extraction = normalize_extracted_answer(
            extracted_answer, choices, question_type, answer_type, precision
        )
        res = safe_equal(normalized_extraction, answer)
    else:
        res = False
    return {"acc": 1.0} if res else {"acc": 0.0}


### MathVista MCQ ###


def process_docs_mcq(dataset):
    return dataset.filter(lambda x: x["question_type"] == "multi_choice")