import argparse
import logging
import os

import yaml
from promptsource.templates import DatasetTemplates
from tqdm import tqdm


# from lm_eval.api.registry import ALL_TASKS
eval_logger = logging.getLogger(__name__)

# from lm_eval.tasks import include_task_folder


def parse_args():
    """Parse CLI flags: benchmark name, benchmark YAML path, and output dir."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--benchmark_name", required=True)
    arg_parser.add_argument("--benchmark_path", required=True)
    arg_parser.add_argument("--task_save_path", default="lm_eval/tasks/")
    return arg_parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Load the benchmark config eagerly so the file handle is not held open
    # for the duration of the generation loop below.
    # NOTE(review): yaml.full_load can instantiate tagged Python objects;
    # acceptable for trusted local configs, but prefer yaml.safe_load if the
    # benchmark file can come from an untrusted source.
    with open(args.benchmark_path, encoding="utf-8") as file:
        TASK_LIST = yaml.full_load(file)

    for task in tqdm(TASK_LIST):
        eval_logger.info(f"Processing {task}")

        dataset_name = task["dataset_path"]
        # A task may optionally target a specific dataset subset
        # (presumably the HF "config name" — confirm against task schema).
        if "dataset_name" in task:
            subset_name = task["dataset_name"]
            file_subdir = f"{dataset_name}/{subset_name}"
        else:
            subset_name = None
            file_subdir = f"{dataset_name}"

        # Generated configs live under
        # <task_save_path>/<dataset>[/<subset>]/promptsource/
        file_path = os.path.join(args.task_save_path, file_subdir, "promptsource/")
        os.makedirs(file_path, exist_ok=True)

        if subset_name is None:
            prompts = DatasetTemplates(dataset_name=dataset_name)
        else:
            prompts = DatasetTemplates(
                dataset_name=dataset_name, subset_name=subset_name
            )

        # Emit one small YAML config per promptsource template; each config
        # includes the shared template file and names the prompt it uses.
        for idx, prompt_name in enumerate(prompts.all_template_names):
            full_file_name = f"promptsource_{idx}.yaml"
            config_dict = {
                "group": args.benchmark_name,
                "include": "promptsource_template.yaml",
                "use_prompts": f"promptsource:{prompt_name}",
            }

            file_save_path = os.path.join(file_path, full_file_name)
            eval_logger.info(f"Save to {file_save_path}")
            with open(file_save_path, "w", encoding="utf-8") as yaml_file:
                yaml.dump(config_dict, yaml_file)