_generate_configs.py 5.56 KB
Newer Older
haileyschoelkopf's avatar
haileyschoelkopf committed
1
2
3
"""
Take in a YAML, and output all other splits with this YAML
"""
4

haileyschoelkopf's avatar
haileyschoelkopf committed
5
import argparse
6
import logging
7
import os
haileyschoelkopf's avatar
haileyschoelkopf committed
8

9
import yaml
haileyschoelkopf's avatar
haileyschoelkopf committed
10
11
from tqdm import tqdm

12
13

# Module-level logger; the script uses it to report where each per-subject YAML
# is written. No logging level/handler is configured in the visible code, so
# INFO messages may not appear unless logging is configured elsewhere.
eval_logger = logging.getLogger(__name__)
haileyschoelkopf's avatar
haileyschoelkopf committed
14

15

haileyschoelkopf's avatar
haileyschoelkopf committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# CMMLU subjects: English key -> Chinese display name.
# The English key is used as the generated config's `dataset_name` and as the
# suffix of the task name; the Chinese name is interpolated into the default
# prompt `description`. Dict iteration order determines the order in which
# per-subject files are generated and the order of tasks in the group config.
SUBJECTS = {
    "agronomy": "农学",
    "anatomy": "解剖学",
    "ancient_chinese": "古汉语",
    "arts": "艺术学",
    "astronomy": "天文学",
    "business_ethics": "商业伦理",
    "chinese_civil_service_exam": "中国公务员考试",
    "chinese_driving_rule": "中国驾驶规则",
    "chinese_food_culture": "中国饮食文化",
    "chinese_foreign_policy": "中国外交政策",
    "chinese_history": "中国历史",
    "chinese_literature": "中国文学",
    "chinese_teacher_qualification": "中国教师资格",
    "clinical_knowledge": "临床知识",
    "college_actuarial_science": "大学精算学",
    "college_education": "大学教育学",
    "college_engineering_hydrology": "大学工程水文学",
    "college_law": "大学法律",
    "college_mathematics": "大学数学",
    "college_medical_statistics": "大学医学统计",
    "college_medicine": "大学医学",
    "computer_science": "计算机科学",
    "computer_security": "计算机安全",
    "conceptual_physics": "概念物理学",
    "construction_project_management": "建设工程管理",
    "economics": "经济学",
    "education": "教育学",
    "electrical_engineering": "电气工程",
    "elementary_chinese": "小学语文",
    "elementary_commonsense": "小学常识",
    "elementary_information_and_technology": "小学信息技术",
    "elementary_mathematics": "初等数学",
    "ethnology": "民族学",
    "food_science": "食品科学",
    "genetics": "遗传学",
    "global_facts": "全球事实",
    "high_school_biology": "高中生物",
    "high_school_chemistry": "高中化学",
    "high_school_geography": "高中地理",
    "high_school_mathematics": "高中数学",
    "high_school_physics": "高中物理学",
    "high_school_politics": "高中政治",
    "human_sexuality": "人类性行为",
    "international_law": "国际法学",
    "journalism": "新闻学",
    "jurisprudence": "法理学",
    "legal_and_moral_basis": "法律与道德基础",
    "logical": "逻辑学",
    "machine_learning": "机器学习",
    "management": "管理学",
    "marketing": "市场营销",
    "marxist_theory": "马克思主义理论",
    "modern_chinese": "现代汉语",
    "nutrition": "营养学",
    "philosophy": "哲学",
    "professional_accounting": "专业会计",
    "professional_law": "专业法学",
    "professional_medicine": "专业医学",
    "professional_psychology": "专业心理学",
    "public_relations": "公共关系",
    "security_study": "安全研究",
    "sociology": "社会学",
    "sports_science": "体育学",
    "traditional_chinese_medicine": "中医中药",
    "virology": "病毒学",
    "world_history": "世界历史",
    "world_religions": "世界宗教",
}


def parse_args(argv=None):
    """Parse command-line options for the CMMLU config generator.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so existing
            ``parse_args()`` calls behave exactly as before.

    Returns:
        An ``argparse.Namespace`` with ``base_yaml_path`` (required),
        ``save_prefix_path`` (default ``"cmmlu"``), ``cot_prompt_path``
        (default ``None``) and ``task_prefix`` (default ``""``).
    """
    parser = argparse.ArgumentParser(
        description="Generate one CMMLU task YAML per subject plus a group YAML."
    )
    parser.add_argument(
        "--base_yaml_path",
        required=True,
        help="Base YAML that every generated config includes (by file name).",
    )
    parser.add_argument(
        "--save_prefix_path",
        default="cmmlu",
        help="Output prefix; per-subject configs are written to "
        "<prefix>_<subject>.yaml.",
    )
    parser.add_argument(
        "--cot_prompt_path",
        default=None,
        help="Optional JSON file mapping subject keys to task descriptions.",
    )
    parser.add_argument(
        "--task_prefix",
        default="",
        help="Optional infix for task names: cmmlu_<task_prefix>_<subject>.",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()

    # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
    base_yaml_name = os.path.split(args.base_yaml_path)[-1]
101
    with open(args.base_yaml_path, encoding="utf-8") as f:
haileyschoelkopf's avatar
haileyschoelkopf committed
102
103
104
105
106
        base_yaml = yaml.full_load(f)

    if args.cot_prompt_path is not None:
        import json

107
        with open(args.cot_prompt_path, encoding="utf-8") as f:
haileyschoelkopf's avatar
haileyschoelkopf committed
108
109
110
111
112
113
            cot_file = json.load(f)

    for subject_eng, subject_zh in tqdm(SUBJECTS.items()):
        if args.cot_prompt_path is not None:
            description = cot_file[subject_eng]
        else:
114
115
116
            description = (
                f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
            )
haileyschoelkopf's avatar
haileyschoelkopf committed
117
118
119
120
121
122
123
124
125
126
127
128

        yaml_dict = {
            "include": base_yaml_name,
            "task": f"cmmlu_{args.task_prefix}_{subject_eng}"
            if args.task_prefix != ""
            else f"cmmlu_{subject_eng}",
            "dataset_name": subject_eng,
            "description": description,
        }

        file_save_path = args.save_prefix_path + f"_{subject_eng}.yaml"
        eval_logger.info(f"Saving yaml for subset {subject_eng} to {file_save_path}")
129
        with open(file_save_path, "w", encoding="utf-8") as yaml_file:
haileyschoelkopf's avatar
haileyschoelkopf committed
130
131
132
133
134
135
136
            yaml.dump(
                yaml_dict,
                yaml_file,
                width=float("inf"),
                allow_unicode=True,
                default_style='"',
            )
Lintang Sutawika's avatar
Lintang Sutawika committed
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166

    # write group config out

    group_yaml_dict = {
        "group": "cmmlu",
        "task": [
            (
                f"cmmlu_{args.task_prefix}_{subject_eng}"
                if args.task_prefix != ""
                else f"cmmlu_{subject_eng}"
            )
            for subject_eng in SUBJECTS.keys()
        ],
        "aggregate_metric_list": [
            {"metric": "acc", "aggregation": "mean", "weight_by_size": True},
            {"metric": "acc_norm", "aggregation": "mean", "weight_by_size": True},
        ],
        "metadata": {"version": 0.0},
    }

    file_save_path = "_" + args.save_prefix_path + ".yaml"

    with open(file_save_path, "w", encoding="utf-8") as group_yaml_file:
        yaml.dump(
            group_yaml_dict,
            group_yaml_file,
            width=float("inf"),
            allow_unicode=True,
            default_style='"',
        )