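"""Build per-prompt-template task configs for a benchmark.

Reads a benchmark definition YAML (a list of task entries, each naming a
dataset and, optionally, a subset), looks up every promptsource template
registered for that dataset, and writes one task config YAML per template
under --task_save_path.

Illustrative invocation (file names are placeholders):

    python build_benchmark.py \
        --benchmark_name my_benchmark \
        --benchmark_path my_benchmark.yaml
"""
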
import argparse
import os

import yaml

from tqdm import tqdm
from promptsource.templates import DatasetTemplates

from lm_eval import utils

# from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger

# from lm_eval.tasks import include_task_folder


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark_name", required=True)
    parser.add_argument("--benchmark_path", required=True)
    parser.add_argument("--task_save_path", default="lm_eval/tasks/")
    return parser.parse_args()


if __name__ == "__main__":

    args = parse_args()
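    # The benchmark file is expected to parse to a list of task entries; a
    # minimal sketch (key names taken from the lookups below, dataset values
    # are illustrative):
    #
    #   - dataset_path: super_glue
    #     dataset_name: boolq   # optional subset
    #   - dataset_path: hellaswag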

    with open(args.benchmark_path) as file:
        TASK_LIST = yaml.full_load(file)
        for task in tqdm(TASK_LIST):
            eval_logger.info(f"Processing {task}")

            dataset_name = task["dataset_path"]
            if "dataset_name" in task:
                subset_name = task["dataset_name"]
                file_subdir = f"{dataset_name}/{subset_name}"
            else:
                subset_name = None
                file_subdir = f"{dataset_name}"

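            # Per-task configs go under <task_save_path>/<dataset>[/<subset>]/promptsource/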
            file_path = os.path.join(args.task_save_path, file_subdir, "promptsource/")

            os.makedirs(file_path, exist_ok=True)

            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )

            for idx, prompt_name in enumerate(prompts.all_template_names):
                full_file_name = f"promptsource_{idx}.yaml"
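                # Each generated config is a small YAML roughly of the form
                # (values depend on the benchmark and template):
                #   group: <benchmark_name>
                #   include: promptsource_template.yaml
                #   use_prompts: promptsource:<prompt_name>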
                config_dict = {
                    "group": args.benchmark_name,
                    "include": "promptsource_template.yaml",
                    "use_prompts": f"promptsource:{prompt_name}",
                }

                file_save_path = os.path.join(file_path, full_file_name)
                eval_logger.info(f"Save to {file_save_path}")
                with open(file_save_path, "w") as yaml_file:
                    yaml.dump(config_dict, yaml_file)