_generate_configs.py 2.6 KB
Newer Older
lintangsutawika's avatar
updates  
lintangsutawika committed
1
2
3
4
5
"""
Take in a YAML, and output all other splits with this YAML
"""
import os
import re
lintangsutawika's avatar
add bbh  
lintangsutawika committed
6
import yaml
lintangsutawika's avatar
updates  
lintangsutawika committed
7
8
import requests
import argparse
lintangsutawika's avatar
add bbh  
lintangsutawika committed
9

lintangsutawika's avatar
updates  
lintangsutawika committed
10
import datasets
lintangsutawika's avatar
add bbh  
lintangsutawika committed
11
12
from tqdm import tqdm

lintangsutawika's avatar
updates  
lintangsutawika committed
13
14
from lm_eval import utils

lintangsutawika's avatar
update  
lintangsutawika committed
15

lintangsutawika's avatar
updates  
lintangsutawika committed
16
17
18
def parse_args():
    """Parse the command-line options of the config generator.

    Options:
        --base_yaml_path:   path of the base YAML (required).
        --save_prefix_path: directory prefix for the generated YAMLs
                            (default ``"zeroshot"``).
        --cot:              boolean flag; accepts ``--cot``, ``--cot True``,
                            ``--cot False`` (default ``False``).
        --fewshot:          boolean flag, same forms as ``--cot``
                            (default ``False``).
        --task_prefix:      string inserted into the generated task name
                            (default ``""``).

    Returns:
        argparse.Namespace: the parsed arguments (read from ``sys.argv``).
    """

    def str2bool(value):
        # Without a converter argparse stores the raw string, and any
        # non-empty string -- including "False" -- is truthy.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1", "on"):
            return True
        if lowered in ("false", "f", "no", "n", "0", "off", ""):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}"
        )

    parser = argparse.ArgumentParser()
    parser.add_argument("--base_yaml_path", required=True)
    parser.add_argument("--save_prefix_path", default="zeroshot")
    # nargs="?" + const=True keeps `--cot True` working while also allowing
    # a bare `--cot`; the non-string default is not passed through `type`.
    parser.add_argument(
        "--cot", type=str2bool, nargs="?", const=True, default=False
    )
    parser.add_argument(
        "--fewshot", type=str2bool, nargs="?", const=True, default=False
    )
    parser.add_argument("--task_prefix", default="")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
lintangsutawika's avatar
add bbh  
lintangsutawika committed
28

lintangsutawika's avatar
updates  
lintangsutawika committed
29
30
31
32
33
34
35
    # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
    base_yaml_name = os.path.split(args.base_yaml_path)[-1]
    with open(args.base_yaml_path) as f:
        base_yaml = yaml.full_load(f)

    base_doc_to_text = "Q: {{input}}\nA:"
    answer_regex = re.compile("(?<=answer is )(.*)(?=.)")
lintangsutawika's avatar
add bbh  
lintangsutawika committed
36
37
38

    dataset_path = "lukaemon/bbh"
    for task in tqdm(datasets.get_dataset_infos(dataset_path).keys()):
lintangsutawika's avatar
update  
lintangsutawika committed
39
40
41
        resp = requests.get(
            f"https://raw.githubusercontent.com/suzgunmirac/BIG-Bench-Hard/main/cot-prompts/{task}.txt"
        ).content.decode("utf-8")
lintangsutawika's avatar
updates  
lintangsutawika committed
42
        prompt = resp.split("\n-----\n")[-1]
43
        description, *few_shot = prompt.split("\n\n")
lintangsutawika's avatar
updates  
lintangsutawika committed
44
45
46
47

        prefix_doc_to_text = ""
        if args.fewshot:
            if args.cot:
momotori's avatar
momotori committed
48
                prefix_doc_to_text = "\n\n".join(few_shot) + "\n\n"
lintangsutawika's avatar
updates  
lintangsutawika committed
49
50
51
52
            else:
                for shot in few_shot:
                    try:
                        answer = answer_regex.search(shot)[0]
lintangsutawika's avatar
lintangsutawika committed
53
                    except Exception:
lintangsutawika's avatar
updates  
lintangsutawika committed
54
55
                        print("task", task)
                        print(shot)
lintangsutawika's avatar
update  
lintangsutawika committed
56
                    example = shot.split("Let's think step by step.")[0]
lintangsutawika's avatar
updates  
lintangsutawika committed
57
58
59
60
61
62
63
                    prefix_doc_to_text += f"{example}{answer}\n\n"

        doc_to_text = prefix_doc_to_text + base_doc_to_text
        if args.cot:
            doc_to_text = doc_to_text + " Let's think step by step.\n"

        yaml_dict = {
lintangsutawika's avatar
update  
lintangsutawika committed
64
65
66
67
68
69
            "include": base_yaml_name,
            "task": f"bbh_{args.task_prefix}_{task}",
            "dataset_name": task,
            "description": description + "\n\n",
            "doc_to_text": doc_to_text,
        }
lintangsutawika's avatar
updates  
lintangsutawika committed
70
71

        file_save_path = args.save_prefix_path + f"/{task}.yaml"
72
        utils.eval_logger.info(f"Saving yaml for subset {task} to {file_save_path}")
lintangsutawika's avatar
updates  
lintangsutawika committed
73
        with open(file_save_path, "w") as yaml_file:
lintangsutawika's avatar
update  
lintangsutawika committed
74
75
76
77
78
79
80
            yaml.dump(
                yaml_dict,
                yaml_file,
                width=float("inf"),
                allow_unicode=True,
                default_style='"',
            )