ja_leaderboard_mgsm.py 837 Bytes
Newer Older
mtkachenko's avatar
mtkachenko committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import re


# Sentinel returned by _extract_answer when no usable integer can be parsed
# from a completion. It is a string, so (assuming doc["answer_number"] is
# numeric) it never compares equal to the reference and scores as incorrect.
_INVALID_ANSWER = "[invalid]"

# Matches maximal runs of ASCII digits, "." and "," with an optional leading
# "-", e.g. "1,234", "-5", "42.". NOTE: it also matches digit-free runs such
# as a lone "." or ",", so matched text must be validated before parsing.
_ANSWER_REGEX = re.compile(r"(\-?[0-9\.\,]+)")


def _extract_answer(completion):
    """Extract the final integer answer from a model completion.

    Finds every numeric-looking token matched by ``_ANSWER_REGEX`` and uses
    the last one that actually contains a digit. Trailing periods are
    stripped and "," thousands separators are removed before parsing.

    Args:
        completion: The raw text generated by the model.

    Returns:
        The answer as an ``int`` when the selected token parses to a whole
        number (e.g. "1,234" -> 1234, "42." -> 42, "7.0" -> 7); otherwise
        ``_INVALID_ANSWER`` (no digits at all, a malformed token such as
        "1.2.3", or a non-integral value such as "2.5").
    """
    # _ANSWER_REGEX also matches digit-free runs like "." or ",", e.g. the
    # sentence-final period in "The answer is 42. Done.". Taking such a token
    # as the "last number" would wrongly mark a correct answer invalid, so
    # only tokens containing at least one digit are considered.
    candidates = [
        token
        for token in _ANSWER_REGEX.findall(completion)
        if any(ch.isdigit() for ch in token)
    ]
    if not candidates:
        return _INVALID_ANSWER

    # Drop sentence-ending dots and thousands separators before parsing.
    match_str = candidates[-1].strip(".").replace(",", "")
    try:
        match_float = float(match_str)
    except ValueError:
        # Malformed numerals such as "1.2.3" are not repaired.
        return _INVALID_ANSWER

    if match_float.is_integer():
        return int(match_float)
    return _INVALID_ANSWER


def process_results(doc, results):
    """Score a single MGSM example by exact match on the integer answer.

    Args:
        doc: The dataset example; its ``"answer_number"`` entry is the
            reference answer the extracted prediction is compared against.
        results: The model outputs for this example; exactly one completion
            string is expected.

    Returns:
        A dict with one key, ``"acc"``, holding the boolean outcome of
        comparing the extracted prediction with the reference answer.
    """
    assert len(results) == 1, (
        f"results should be a list with 1 str element, but is {results}"
    )

    # Extract the prediction before reading the reference, as before.
    prediction = _extract_answer(results[0])
    reference = doc["answer_number"]
    return {"acc": prediction == reference}