write_out.py 2.54 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import numpy as np
3
import json
Jason Phang's avatar
Jason Phang committed
4
5
6
import os
import random
from lm_eval import tasks
Leo Gao's avatar
Leo Gao committed
7
from lm_eval.utils import join_iters
Jason Phang's avatar
Jason Phang committed
8
9
10
11
12
13

# Marker line written before every example in an output file; the "{i}"
# placeholder is filled with the example's index via str.format.
EXAMPLE_DIVIDER: str = "!!@@##@@!! -- Example {i}\n"


def parse_args(argv=None):
    """Parse command-line options for writing out few-shot prompt samples.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic calls).
            Defaults to ``None``, which reads the real command line.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Write sample few-shot prompts for lm_eval tasks to disk."
    )
    parser.add_argument(
        "--output_base_path",
        required=True,
        help="Directory to write one output file per task into.",
    )
    parser.add_argument(
        "--tasks",
        default="all_tasks",
        help='Comma-separated task names, or "all_tasks" for every task.',
    )
    parser.add_argument(
        "--provide_description",
        action="store_true",
        help="Currently not read by this script; task descriptions come "
        "from --description_dict_path.",
    )
    parser.add_argument(
        "--sets",
        type=str,
        default="val",
        help="Comma-separated splits to draw documents from "
        "(any of train, val, test), e.g. val,test.",
    )
    parser.add_argument(
        "--num_fewshot",
        type=int,
        default=1,
        help="Number of few-shot examples included in each prompt.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for numpy's global RNG and the per-task random.Random "
        "used to sample few-shot examples.",
    )
    parser.add_argument(
        "--num_examples",
        type=int,
        default=1,
        help="Number of prompts to write per task; values <= 0 write one "
        "prompt per available document.",
    )
    parser.add_argument(
        "--description_dict_path",
        default=None,
        help="Optional path to a JSON file mapping task names to "
        "description strings prepended to prompts.",
    )
    # argv=None makes argparse fall back to sys.argv[1:], preserving the
    # original zero-argument behaviour.
    return parser.parse_args(argv)


def main():
    """Render few-shot prompts for the selected tasks and write them to disk.

    For each task chosen by ``--tasks``, documents from the splits listed in
    ``--sets`` are concatenated in the order given. The first
    ``--num_examples`` of them (every document when ``--num_examples`` <= 0)
    are turned into prompts with ``task.fewshot_context``. Each prompt is
    written, preceded by an ``EXAMPLE_DIVIDER`` line, to
    ``<output_base_path>/<task name with "/" replaced by "_">``.

    Splits that a task does not provide are skipped with a warning on stderr.

    Raises:
        ValueError: if ``--sets`` contains a name other than "train", "val"
            or "test".
    """
    import sys  # local import: only needed for warnings written to stderr

    args = parse_args()
    # Also seed numpy's global RNG, presumably for task code that samples
    # with it.
    np.random.seed(args.seed)

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict_promptsource(task_names)

    description_dict = {}
    if args.description_dict_path:
        with open(args.description_dict_path, "r", encoding="utf-8") as f:
            description_dict = json.load(f)

    # --sets value -> (availability-check method, document-accessor method).
    split_methods = {
        "train": ("has_training_docs", "training_docs"),
        "val": ("has_validation_docs", "validation_docs"),
        "test": ("has_test_docs", "test_docs"),
    }
    # Ignore empty entries, e.g. the one produced by a trailing comma.
    splits = [s.strip() for s in args.sets.split(",") if s.strip()]
    unknown = [s for s in splits if s not in split_methods]
    if unknown:
        # Validate before writing anything so a typo cannot leave partial
        # output behind.
        raise ValueError(
            f"Unknown --sets value(s) {unknown}; expected any of "
            f"{', '.join(split_methods)}"
        )

    os.makedirs(args.output_base_path, exist_ok=True)
    for task_name, task in task_dict.items():
        # Fresh, identically seeded RNG per task so each task's few-shot
        # sampling does not depend on which tasks were processed before it.
        rnd = random.Random(args.seed)

        iters = []
        for split in splits:
            has_docs_name, docs_name = split_methods[split]
            if not getattr(task, has_docs_name)():
                # Skip splits the task does not provide rather than reusing
                # another split's documents.
                print(
                    f"Warning: task {task_name!r} has no {split!r} docs; "
                    "skipping",
                    file=sys.stderr,
                )
                continue
            iters.append(getattr(task, docs_name)())
        docs = join_iters(iters)

        description = (
            description_dict.get(task_name, "") if description_dict else ""
        )
        # Task names may contain "/", which cannot appear in a file name.
        file_name = task_name.replace("/", "_")
        out_path = os.path.join(args.output_base_path, file_name)
        with open(out_path, "w", encoding="utf-8") as f:
            if args.num_examples > 0:
                # zip stops after num_examples without consuming extra docs.
                numbered_docs = zip(range(args.num_examples), docs)
            else:
                numbered_docs = enumerate(docs)
            for i, doc in numbered_docs:
                f.write(EXAMPLE_DIVIDER.format(i=i))
                ctx, _ = task.fewshot_context(
                    doc=doc,
                    num_fewshot=args.num_fewshot,
                    rnd=rnd,
                    description=description,
                )
                f.write(ctx + "\n")


if __name__ == "__main__":
    main()