# ceval.py
import copy
import csv
import os
from typing import Dict, List

from colossalai.logging import DistributedLogger

from .base import BaseDataset

# C-Eval subjects grouped by top-level category. Each entry is
# (file-name key, English display name, Chinese display name).
_CEVAL_SUBJECTS_BY_CATEGORY = (
    (
        "STEM",
        (
            ("computer_network", "Computer Network", "计算机网络"),
            ("operating_system", "Operating System", "操作系统"),
            ("computer_architecture", "Computer Architecture", "计算机组成"),
            ("college_programming", "College Programming", "大学编程"),
            ("college_physics", "College Physics", "大学物理"),
            ("college_chemistry", "College Chemistry", "大学化学"),
            ("advanced_mathematics", "Advanced Mathematics", "高等数学"),
            ("probability_and_statistics", "Probability and Statistics", "概率统计"),
            ("discrete_mathematics", "Discrete Mathematics", "离散数学"),
            ("electrical_engineer", "Electrical Engineer", "注册电气工程师"),
            ("metrology_engineer", "Metrology Engineer", "注册计量师"),
            ("high_school_mathematics", "High School Mathematics", "高中数学"),
            ("high_school_physics", "High School Physics", "高中物理"),
            ("high_school_chemistry", "High School Chemistry", "高中化学"),
            ("high_school_biology", "High School Biology", "高中生物"),
            ("middle_school_mathematics", "Middle School Mathematics", "初中数学"),
            ("middle_school_biology", "Middle School Biology", "初中生物"),
            ("middle_school_physics", "Middle School Physics", "初中物理"),
            ("middle_school_chemistry", "Middle School Chemistry", "初中化学"),
            ("veterinary_medicine", "Veterinary Medicine", "兽医学"),
        ),
    ),
    (
        "Social Science",
        (
            ("college_economics", "College Economics", "大学经济学"),
            ("business_administration", "Business Administration", "工商管理"),
            ("marxism", "Marxism", "马克思主义基本原理"),
            ("mao_zedong_thought", "Mao Zedong Thought", "毛泽东思想和中国特色社会主义理论体系概论"),
            ("education_science", "Education Science", "教育学"),
            ("teacher_qualification", "Teacher Qualification", "教师资格"),
            ("high_school_politics", "High School Politics", "高中政治"),
            ("high_school_geography", "High School Geography", "高中地理"),
            ("middle_school_politics", "Middle School Politics", "初中政治"),
            ("middle_school_geography", "Middle School Geography", "初中地理"),
        ),
    ),
    (
        "Humanities",
        (
            ("modern_chinese_history", "Modern Chinese History", "近代史纲要"),
            ("ideological_and_moral_cultivation", "Ideological and Moral Cultivation", "思想道德修养与法律基础"),
            ("logic", "Logic", "逻辑学"),
            ("law", "Law", "法学"),
            ("chinese_language_and_literature", "Chinese Language and Literature", "中国语言文学"),
            ("art_studies", "Art Studies", "艺术学"),
            ("professional_tour_guide", "Professional Tour Guide", "导游资格"),
            ("legal_professional", "Legal Professional", "法律职业资格"),
            ("high_school_chinese", "High School Chinese", "高中语文"),
            ("high_school_history", "High School History", "高中历史"),
            ("middle_school_history", "Middle School History", "初中历史"),
        ),
    ),
    (
        "Other",
        (
            ("civil_servant", "Civil Servant", "公务员"),
            ("sports_science", "Sports Science", "体育学"),
            ("plant_protection", "Plant Protection", "植物保护"),
            ("basic_medicine", "Basic Medicine", "基础医学"),
            ("clinical_medicine", "Clinical Medicine", "临床医学"),
            ("urban_and_rural_planner", "Urban and Rural Planner", "注册城乡规划师"),
            ("accountant", "Accountant", "注册会计师"),
            ("fire_engineer", "Fire Engineer", "注册消防工程师"),
            (
                "environmental_impact_assessment_engineer",
                "Environmental Impact Assessment Engineer",
                "环境影响评价工程师",
            ),
            ("tax_accountant", "Tax Accountant", "税务师"),
            ("physician", "Physician", "医师资格"),
        ),
    ),
)

# Flattened lookup: file-name key -> [English name, Chinese name, category].
# Insertion order follows the table above (category by category).
ceval_subject_mapping = {
    key: [english_name, chinese_name, category]
    for category, subjects in _CEVAL_SUBJECTS_BY_CATEGORY
    for key, english_name, chinese_name in subjects
}

# Inference settings applied to every C-Eval subject: four-way multiple choice
# (A-D), Chinese prompts, generation capped at 32 new tokens.
default_inference_kwargs = dict(
    calculate_loss=False,
    all_classes=list("ABCD"),
    language="Chinese",
    pretrain=False,
    max_new_tokens=32,
)


def get_few_shot_data(data: List[Dict], subject: str) -> List[str]:
    """Build the few-shot prompt pieces for one C-Eval subject.

    Args:
        data: Already-parsed samples of the subject (the dev split, as called by
            ``CEvalDataset.load``); each must provide string ``"input"`` and
            ``"target"`` fields.
        subject: Subject name inserted into the instruction header.

    Returns:
        A list whose first element is the instruction header, followed by one
        string per sample: its ``input`` immediately followed by its ``target``.
    """
    header = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。"
    return [header, *(sample["input"] + sample["target"] for sample in data)]


class CEvalDataset(BaseDataset):
    """
    Dataset class for CEval dataset.
    Data source: https://huggingface.co/datasets/ceval/ceval-exam
    This dataset class will convert the original dataset into the inference dataset.
    """

    @staticmethod
    def load(
        path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
    ) -> Dict[str, Dict[str, Dict]]:
        """Load the C-Eval ``dev`` and ``test`` splits found under ``path``.

        Reads every ``<subject>_<split>.csv`` in ``path/dev`` and ``path/test``,
        where ``<subject>`` is a key of ``ceval_subject_mapping``. Files without
        that suffix are skipped with a warning.

        Args:
            path: Root directory containing the ``dev`` and ``test`` folders.
            logger: Used to report skipped files.
            few_shot: If True, each test subject's inference kwargs get a
                ``few_shot_data`` entry built from that subject's dev samples.
            forward_only: Not used by this loader.
            load_train: Not used by this loader.
            load_reference: Not used by this loader.

        Returns:
            ``{"dev": {...}, "test": {...}}``. Each split maps the Chinese subject
            name to ``{"data": [sample, ...], "inference_kwargs": {...}}``.

        Raises:
            KeyError: A CSV's subject key is not in ``ceval_subject_mapping``, or
                ``few_shot`` is set and a test subject has no dev file.
            ValueError: A CSV row has fewer columns than its split requires.
        """
        dataset = {"dev": {}, "test": {}}
        # "dev" must be loaded before "test": few-shot prompts for the test split
        # are built from the dev samples already stored in dataset["dev"].
        for split in ["dev", "test"]:
            suffix = f"_{split}.csv"
            for file in sorted(os.listdir(os.path.join(path, split))):
                file_dir = os.path.join(path, split, file)
                if not file.endswith(suffix):
                    # Don't let stray files (e.g. OS metadata) abort the whole load.
                    logger.warning(f"Skipping {file_dir}: expected a file named <subject>{suffix}.")
                    continue

                # Index 1 of a mapping entry is the Chinese subject name; it is the key used below.
                subject = ceval_subject_mapping[file[: -len(suffix)]][1]

                # It's been tested that each data sample in one subcategory has the same inference arguments.
                # Deep-copy so per-subject additions (few_shot_data) never touch the shared defaults.
                inference_kwargs = copy.deepcopy(default_inference_kwargs)
                if split == "test" and few_shot:
                    inference_kwargs["few_shot_data"] = get_few_shot_data(dataset["dev"][subject]["data"], subject)

                dataset[split][subject] = {
                    "data": CEvalDataset._load_split_file(file_dir, split, subject),
                    "inference_kwargs": inference_kwargs,
                }

        return dataset

    @staticmethod
    def _load_split_file(file_dir: str, split: str, subject: str) -> List[Dict]:
        """Parse one C-Eval CSV file into inference samples for ``subject``."""
        instruction = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。"
        # Columns: id, question, A, B, C, D, then (dev only) answer and explanation.
        # Only dev rows need the 7th column, because only they read the answer (row[6]).
        min_columns = 7 if split == "dev" else 6
        samples = []
        # The csv module requires newline="" so line breaks inside quoted fields are kept intact.
        with open(file_dir, encoding="utf-8", newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # Skip the header row; an empty file just yields no samples.
            for row in reader:
                if not row:
                    continue  # csv yields [] for blank lines (e.g. trailing newlines).
                if len(row) < min_columns:
                    raise ValueError(
                        f"{file_dir}: expected at least {min_columns} columns in a {split} row, "
                        f"got {len(row)}: {row!r}"
                    )
                choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}"
                samples.append(
                    {
                        "dataset": "ceval",
                        "split": split,
                        "category": subject,
                        "instruction": instruction,
                        "input": f"题目:{row[1]}\n{choices}\n答案:",
                        "output": "",
                        "target": row[6] if split == "dev" else "",
                        "id": int(row[0]),
                    }
                )
        return samples