utils.py 4.33 KB
Newer Older
Baber's avatar
Baber committed
1
import re
Baber's avatar
Baber committed
2
from typing import Optional
Baber's avatar
Baber committed
3
4
5
6
7

from Levenshtein import distance


# taken from https://github.com/lupantech/MathVista/blob/main/evaluation/calculate_score.py
Baber's avatar
Baber committed
8
def get_most_similar(prediction: str, choices: list) -> str:
    """
    Return the entry of ``choices`` that is closest to ``prediction``.

    Closeness is the Levenshtein (edit) distance computed by ``distance``.
    When several choices are equally close, the earliest one in ``choices``
    is returned.

    Raises:
        ValueError: if ``choices`` is empty (raised by ``min``).
    """
    # min() keeps the first minimal element, so ties resolve to the earliest choice.
    return min(choices, key=lambda choice: distance(prediction, choice))


Baber's avatar
Baber committed
18
# taken from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
Baber's avatar
Baber committed
19
def normalize_extracted_answer(
    extraction: str,
    choices: list,
    question_type: str,
    answer_type: str,
    precision,
    ignore_empty_extractions=True,
) -> Optional[str]:
    """
    Normalize the extracted answer to match the answer type.

    Args:
        extraction: raw answer text pulled from the model response.
        choices: answer options; only read when ``question_type`` is
            "multi_choice".
        question_type: "multi_choice" selects option matching; any other value
            falls through to the ``answer_type`` branches.
        answer_type: "integer", "float" or "list".
        precision: number of decimals handed to ``round`` for "float" answers.
            An integral float such as ``2.0`` is accepted and treated as ``2``.
        ignore_empty_extractions: when true, an empty multi-choice extraction
            yields ``None`` instead of being matched against the choices.

    Returns:
        The normalized answer, or ``None`` when the extraction is empty (see
        above), cannot be converted, or the answer type is not recognized.
    """
    # Bind a default so an unrecognized answer type returns None instead of
    # raising UnboundLocalError at the final return.
    normalized_extraction = None

    if question_type == "multi_choice":
        # make sure the extraction is a string
        if isinstance(extraction, str):
            extraction = extraction.strip()
        else:
            try:
                extraction = str(extraction)
            except Exception:
                extraction = ""

        # if the extraction is empty, return None
        if ignore_empty_extractions and not extraction:
            return None

        # extract "A" from "(A) text"; only the first parenthesized letter counts
        letter = re.search(r"\(([a-zA-Z])\)", extraction)
        if letter:
            extraction = letter.group(1).upper()

        # "A", "B", ... one label per available choice
        sequential_characters = [chr(ord("A") + i) for i in range(len(choices))]

        # if model output a character, use it as index of available choices
        if extraction in sequential_characters:
            option_index = sequential_characters.index(extraction)
            normalized_extraction = choices[option_index]
        else:
            # select the most similar option
            normalized_extraction = get_most_similar(extraction, choices)
        assert normalized_extraction in choices

    elif answer_type == "integer":
        try:
            # go through float so that "3.0" is accepted and becomes "3"
            normalized_extraction = str(int(float(extraction)))
        except Exception:
            normalized_extraction = None

    elif answer_type == "float":
        try:
            # round() rejects float ndigits; accept integral floats (e.g. 2.0)
            # rather than silently turning a valid answer into None.
            ndigits = precision
            if isinstance(precision, float) and precision.is_integer():
                ndigits = int(precision)
            normalized_extraction = str(round(float(extraction), ndigits))
        except Exception:
            normalized_extraction = None

    elif answer_type == "list":
        try:
            normalized_extraction = str(extraction)
        except Exception:
            normalized_extraction = None

    return normalized_extraction


def safe_equal(prediction, answer):
    """
    Compare the prediction with the answer, treating any failure as a mismatch.

    Returns True only when ``prediction == answer`` evaluates truthy; a falsy
    result, or an exception raised while comparing, gives False.
    """
    try:
        matched = bool(prediction == answer)
    except Exception:
        matched = False
    return matched


Baber's avatar
Baber committed
94
95
96
97
98
99
def extract_answer(response: str, problem: dict) -> str:
    """
    Pull a short answer string out of a raw model response.

    Args:
        response: the model's full response text.
        problem: record providing the "question_type", "answer_type" and
            "choices" keys.

    Returns:
        The first of the following that applies, otherwise "":
        the number following "The answer is" / "the answer is"; the response
        itself when it is exactly one of the multi-choice options; the response
        parsed as an int ("integer") or as a float ("float"), as a string.
    """
    question_type = problem["question_type"]
    answer_type = problem["answer_type"]
    choices = problem["choices"]

    if response == "":
        return ""

    ### This is not in the original code:
    # Allow an optional sign and decimal part: a digits-only pattern cut
    # "3.14" down to "3" and did not match negative numbers at all.
    stated = re.search(r"[tT]he answer is (-?\d+(?:\.\d+)?)", response)
    if stated:
        return stated.group(1)
    ###

    if question_type == "multi_choice" and response in choices:
        return response

    if answer_type == "integer":
        try:
            return str(int(response))
        except Exception:
            # not a bare integer; fall through to the empty result
            pass

    if answer_type == "float":
        try:
            return str(float(response))
        except Exception:
            # not a bare float; fall through to the empty result
            pass

    return ""


# adapted from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
def process_results(doc: dict, results: list[str]):
    """
    Score one document against the first entry of ``results``.

    The response is run through ``extract_answer``, the extraction through
    ``normalize_extracted_answer``, and the outcome is compared with
    ``doc["answer"]`` via ``safe_equal``.

    Returns:
        ``{"acc": 1.0}`` when the normalized answer equals the reference,
        ``{"acc": 0.0}`` otherwise.
    """
    model_output = results[0]
    # Read every required field up front, in a fixed order, before any processing.
    options, q_type, a_type, digits, gold = (
        doc[key]
        for key in ("choices", "question_type", "answer_type", "precision", "answer")
    )
    raw_answer = extract_answer(model_output, doc)
    cleaned_answer = normalize_extracted_answer(
        raw_answer, options, q_type, a_type, digits
    )
    is_correct = safe_equal(cleaned_answer, gold)
    return {"acc": 1.0 if is_correct else 0.0}