build_benchmark.py
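"""Generate per-prompt task configs for a benchmark.

Reads a benchmark YAML listing tasks and, for each task, writes one YAML
config per Promptsource template, all pointing at a shared
promptsource_template.yaml.

Example invocation (names and paths are illustrative):
    python build_benchmark.py --benchmark_name my_benchmark \
        --benchmark_path benchmarks/my_benchmark.yaml
"""
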
import argparse
import os

import yaml
from promptsource.templates import DatasetTemplates
from tqdm import tqdm

# from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger

# from lm_eval.tasks import include_task_folder


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark_name", required=True)
    parser.add_argument("--benchmark_path", required=True)
    parser.add_argument("--task_save_path", default="lm_eval/tasks/")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    with open(args.benchmark_path) as file:
        TASK_LIST = yaml.full_load(file)
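        # Each task entry is expected to provide "dataset_path" (the dataset
        # to load) and, optionally, "dataset_name" (its subset/config),
        # matching the keys read below.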
        for task in tqdm(TASK_LIST):
            eval_logger.info(f"Processing {task}")

            dataset_name = task["dataset_path"]
            if "dataset_name" in task:
                subset_name = task["dataset_name"]
                file_subdir = f"{dataset_name}/{subset_name}"
            else:
                subset_name = None
                file_subdir = f"{dataset_name}"

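            # Generated configs land under
            # <task_save_path>/<dataset>[/<subset>]/promptsource/.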
            file_path = os.path.join(args.task_save_path, file_subdir, "promptsource/")

            os.makedirs(file_path, exist_ok=True)

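            # Collect every Promptsource template available for this dataset
            # (and subset, when one is given).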
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )

            for idx, prompt_name in enumerate(prompts.all_template_names):
                full_file_name = f"promptsource_{idx}.yaml"
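                # Each generated config sets the benchmark name as its group,
                # includes the shared promptsource_template.yaml, and selects
                # a single Promptsource prompt by name.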
                config_dict = {
                    "group": args.benchmark_name,
                    "include": "promptsource_template.yaml",
                    "use_prompts": f"promptsource:{prompt_name}",
                }

                file_save_path = os.path.join(file_path, full_file_name)
                eval_logger.info(f"Save to {file_save_path}")
                with open(file_save_path, "w") as yaml_file:
                    yaml.dump(config_dict, yaml_file)