# write_out.py
# Committed by Jason Phang.
import argparse
import numpy as np
import os
import random

from lm_eval import tasks

# Header line written before each example in an output file. The ``{i}``
# placeholder is filled in with ``str.format`` using the example's
# zero-based index.
EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n"


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the example write-out script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The parsed namespace with the attributes ``output_base_path``,
        ``tasks``, ``provide_description``, ``num_fewshot``, ``seed`` and
        ``num_examples``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_base_path",
        required=True,
        help="directory that receives one output file per task",
    )
    parser.add_argument(
        "--tasks",
        default="all_tasks",
        help='comma-separated task names, or "all_tasks" for every task',
    )
    parser.add_argument(
        "--provide_description",
        action="store_true",
        help="forwarded to the task's fewshot_context call",
    )
    parser.add_argument(
        "--num_fewshot",
        type=int,
        default=1,
        help="forwarded to the task's fewshot_context call",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1234,
        help="seed for the random and numpy.random generators",
    )
    parser.add_argument(
        "--num_examples",
        type=int,
        default=1,
        help="maximum number of examples written per task",
    )
    return parser.parse_args(argv)


def main():
    """Write few-shot example contexts for each selected task to disk.

    Parses the command-line options, seeds Python's ``random`` module and
    NumPy's global generator with ``--seed``, resolves the requested tasks
    through ``lm_eval.tasks`` and writes one file per task, named after the
    task, under ``--output_base_path``.

    Each file holds at most ``--num_examples`` entries. An entry is an
    ``EXAMPLE_DIVIDER`` header line followed by the string returned by
    ``task.fewshot_context`` and a trailing newline.
    """
    args = parse_args()

    # Seed both generators before any task code runs.
    # NOTE(review): presumably this makes any sampling inside
    # fewshot_context reproducible -- confirm against lm_eval.
    random.seed(args.seed)
    np.random.seed(args.seed)

    # "all_tasks" is a sentinel for the full registry; anything else is
    # treated as a comma-separated list of task names.
    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)

    os.makedirs(args.output_base_path, exist_ok=True)
    for task_name, task in task_dict.items():
        # Use validation docs when the task has them, otherwise fall back
        # to the training docs.
        if task.has_validation_docs():
            docs = task.validation_docs()
        else:
            docs = task.training_docs()

        out_path = os.path.join(args.output_base_path, task_name)
        # Explicit UTF-8 so the output does not depend on the platform's
        # locale encoding and non-ASCII context text cannot fail to encode.
        with open(out_path, "w", encoding="utf-8") as f:
            # With the range first, zip() stops once num_examples items
            # have been written, without pulling an extra doc.
            for i, doc in zip(range(args.num_examples), docs):
                f.write(EXAMPLE_DIVIDER.format(i=i))
                ctx = task.fewshot_context(
                    doc=doc,
                    provide_description=args.provide_description,
                    num_fewshot=args.num_fewshot,
                )
                f.write(ctx + "\n")


# Run the write-out only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()