_generate_configs.py 2.63 KB
Newer Older
lintangsutawika's avatar
updates  
lintangsutawika committed
1
2
3
"""
Take a base YAML config and write one task YAML per BBH subset, each of which includes the base YAML.
"""
4

5
import argparse
lintangsutawika's avatar
updates  
lintangsutawika committed
6
7
import os
import re
lintangsutawika's avatar
add bbh  
lintangsutawika committed
8

lintangsutawika's avatar
updates  
lintangsutawika committed
9
import datasets
10
11
import requests
import yaml
lintangsutawika's avatar
add bbh  
lintangsutawika committed
12
13
from tqdm import tqdm

lintangsutawika's avatar
update  
lintangsutawika committed
14

lintangsutawika's avatar
updates  
lintangsutawika committed
15
16
17
def _str_to_bool(value):
    """Convert a command-line token such as "True", "false" or "1" to a bool.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean
            spelling, so argparse reports a usage error instead of silently
            treating an arbitrary non-empty string as true.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string was falsy before this conversion existed; keep it so.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of the config generator.

    Options:
        --base_yaml_path: path of the base YAML to include (required).
        --save_prefix_path: directory prefix for the generated YAML files.
        --cot: boolean switch for chain-of-thought prompts.
        --fewshot: boolean switch for few-shot prompts.
        --task_prefix: string inserted into every generated task name.

    ``--cot`` and ``--fewshot`` accept an explicit value (``--cot True``,
    ``--cot false``) or may be given bare (``--cot``), which means true.
    Without a converter any supplied value, including the literal string
    "False", would be a non-empty and therefore truthy string.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_yaml_path", required=True)
    parser.add_argument("--save_prefix_path", default="zeroshot")
    parser.add_argument(
        "--cot", type=_str_to_bool, nargs="?", const=True, default=False
    )
    parser.add_argument(
        "--fewshot", type=_str_to_bool, nargs="?", const=True, default=False
    )
    parser.add_argument("--task_prefix", default="")
    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: for every subset of the BBH dataset, download its
    # chain-of-thought prompt file, build the task's `doc_to_text` template
    # and write a small YAML that includes the base YAML.
    args = parse_args()

    # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
    base_yaml_name = os.path.split(args.base_yaml_path)[-1]
    with open(args.base_yaml_path, encoding="utf-8") as f:
        # The parsed content is not used below; loading it makes the script
        # fail early when the base YAML is missing or malformed.
        yaml.full_load(f)

    base_doc_to_text = "Q: {{input}}\nA:"
    # Capture the text between "answer is " and the sentence-ending period.
    # The period is escaped so the trailing "." is matched literally rather
    # than as "any character".
    answer_regex = re.compile(r"(?<=answer is )(.*)(?=\.)")

    # Make sure the output directory exists before the first file is written.
    if args.save_prefix_path:
        os.makedirs(args.save_prefix_path, exist_ok=True)

    dataset_path = "lukaemon/bbh"
    for task in tqdm(datasets.get_dataset_infos(dataset_path).keys()):
        response = requests.get(
            f"https://raw.githubusercontent.com/suzgunmirac/BIG-Bench-Hard/main/cot-prompts/{task}.txt",
            timeout=30,
        )
        # Stop on HTTP errors instead of parsing an error page as a prompt.
        response.raise_for_status()
        resp = response.content.decode("utf-8")

        # Keep only the text after the last "-----" separator line, then split
        # it into the task description and the few-shot examples.
        prompt = resp.split("\n-----\n")[-1]
        description, *few_shot = prompt.split("\n\n")

        prefix_doc_to_text = ""
        if args.fewshot:
            if args.cot:
                # Keep the full worked examples, reasoning included.
                prefix_doc_to_text = "\n\n".join(few_shot) + "\n\n"
            else:
                # Drop the reasoning: keep the text before the chain-of-thought
                # trigger and append only the extracted final answer.
                shots = []
                for shot in few_shot:
                    match = answer_regex.search(shot)
                    if match is None:
                        raise ValueError(
                            f"no 'answer is ...' sentence found in a few-shot "
                            f"example of task {task!r}:\n{shot}"
                        )
                    answer = match[0]
                    example = shot.split("Let's think step by step.")[0]
                    shots.append(f"{example}{answer}\n\n")
                prefix_doc_to_text = "".join(shots)

        doc_to_text = prefix_doc_to_text + base_doc_to_text
        if args.cot:
            doc_to_text = doc_to_text + " Let's think step by step.\n"

        yaml_dict = {
            "include": base_yaml_name,
            "task": f"bbh_{args.task_prefix}_{task}",
            "dataset_name": task,
            "description": description + "\n\n",
            "doc_to_text": doc_to_text,
        }

        file_save_path = args.save_prefix_path + f"/{task}.yaml"
        print(f"Saving yaml for subset {task} to {file_save_path}")
        with open(file_save_path, "w", encoding="utf-8") as yaml_file:
            yaml.dump(
                yaml_dict,
                yaml_file,
                # Infinite width keeps each value on a single line.
                width=float("inf"),
                allow_unicode=True,
                default_style='"',
            )