# mathbench.py -- MathBench dataset loader and answer post-processor.
import copy
import json
import os.path as osp
import re

from datasets import Dataset

from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS

from .base import BaseDataset


def get_number(options):
    """Render ``options`` as lettered lines.

    The first option is labelled ``A``, the second ``B``, and so on; each
    option is emitted as ``"<letter>. <option>"`` followed by a newline.

    Args:
        options: Iterable of option values (formatted with ``str()``).

    Returns:
        str: The concatenated lettered lines, or ``''`` for no options.
    """
    first_letter = ord('A')
    labelled = [
        f'{chr(first_letter + offset)}. {text}\n'
        for offset, text in enumerate(options)
    ]
    return ''.join(labelled)


def get_circular_example(entry, id):
    """Build the four circular (rotated-option) variants of one question.

    For each rotation pattern (``ABCD``, ``BCDA``, ``CDAB``, ``DABC``) the
    entry is deep-copied, its first four options are reordered according to
    the pattern, the gold answer letter is remapped to the option's new
    position, and the lettered options are appended to the question text.

    Args:
        entry (dict): Raw example with the keys ``'question'`` (str),
            ``'options'`` (sequence with at least four items) and
            ``'answer'`` (one of ``'A'``-``'D'``; surrounding whitespace is
            ignored). The entry itself is not modified.
        id: Identifier of the example; embedded in the generated answer.

    Returns:
        list[dict]: Four entries whose ``'answer'`` has the form
        ``'<id>--<remapped letter>--<pattern>'``.

    Raises:
        KeyError: If the (stripped) answer is not one of ``'A'``-``'D'``.
        IndexError: If the entry has fewer than four options.
    """
    # Only 4 options is supported for current circular eval.
    circular_patterns = ['ABCD', 'BCDA', 'CDAB', 'DABC']
    # Strip the answer so that values such as 'A ' or 'A\n' are accepted;
    # the raw value would otherwise miss the letter lookup below.
    gold = entry['answer'].strip()
    data = []
    for pattern in circular_patterns:
        line = copy.deepcopy(entry)
        # pattern[k] names the original option that moves to position k.
        line['options'] = [
            line['options'][ord(letter) - ord('A')] for letter in pattern
        ]
        # Map the original answer letter to the letter of its new position.
        remapped = dict(zip(pattern, 'ABCD'))[gold]
        line['answer'] = str(id) + '--' + remapped + '--' + pattern
        line['question'] = line['question'].strip() + '\n' + get_number(
            line['options'])
        data.append(line)

    return data


@LOAD_DATASET.register_module()
class MathBenchDataset(BaseDataset):
    """Loader for MathBench subsets stored as ``<name>.jsonl`` files."""

    @staticmethod
    def load(path: str, name: str, with_circular: bool = True):
        """Load one MathBench subset.

        Args:
            path (str): Path of the mathbench dataset.
            name (str): Name of the target subset. Subsets whose name
                contains ``'cloze'`` are loaded as question/answer pairs;
                all others are treated as single choice questions.
            with_circular (bool): Whether to create circular dataset for
                single choice question. Defaults to True.

        Returns:
            Dataset: Dataset built from the parsed examples.
        """
        data = []
        filename = osp.join(path, f'{name}.jsonl')
        is_cloze = 'cloze' in name
        with open(filename, 'r', encoding='utf-8') as infile:
            # ``idx`` is the physical line index, so skipping blank lines
            # does not shift the ids of the remaining examples.
            for idx, line in enumerate(infile):
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline), which
                    # json.loads would otherwise reject.
                    continue
                entry = json.loads(line)
                if is_cloze:
                    data.append({
                        'question': entry['question'].strip(),
                        'answer': entry['answer'].strip()
                    })
                elif with_circular:
                    data.extend(get_circular_example(entry, idx))
                else:
                    question = entry['question'].strip(
                    ) + '\n' + get_number(entry['options'])
                    info = {
                        'question': question,
                        'answer': entry['answer'].strip()
                    }
                    # For PPL evaluation
                    for i in range(4):
                        info[chr(ord('A') +
                                 i)] = entry['options'][i].strip()
                    data.append(info)

        dataset = Dataset.from_list(data)
        return dataset


@TEXT_POSTPROCESSORS.register_module()
def mathbench_postprocess(text: str, name: str) -> str:
    """Extract the numeric answer from a model response.

    The response is split on an answer marker: ``'答案是'`` when ``name``
    contains ``'_cn'``, otherwise ``'The answer is'``. If the marker is
    present, only the text following its first occurrence (up to the next
    occurrence, if any) is searched and the first number found is returned.
    If the marker is absent, the whole response is searched and the last
    number found is returned.

    Commas between two digits (thousands separators) are removed before
    searching. Numbers may be negative, decimal, or fractions like ``1/2``.

    Args:
        text (str): Raw model output.
        name (str): Dataset name, used only to pick the answer marker.

    Returns:
        str: The extracted number as a string, or the searched text itself
        (stripped when the marker was found) if it contains no number.
    """
    marker = '答案是' if '_cn' in name else 'The answer is'
    segments = text.split(marker)
    marker_found = len(segments) != 1
    # With a marker, the answer is expected right after it; otherwise fall
    # back to the complete response.
    ans = segments[1].strip() if marker_found else text

    # Drop thousands separators so that '1,234' is read as one number.
    output = re.sub(r'(\d),(\d)', r'\1\2', ans)
    numbers = re.findall(r'-?\d*\.?/?\d+|\d+', output)

    if not numbers:
        return ans
    # After a marker the first number is the answer; without one, the last
    # number in the response is taken as the final result.
    return numbers[0] if marker_found else numbers[-1]