# write_out.py: write few-shot prompt contexts for lm_eval tasks to text files,
# one file per task.
import argparse
import itertools
import json
import os
import random
import sys

import numpy as np

from lm_eval import tasks
from lm_eval.utils import join_iters

# Header line written before every example in a task's output file. It is a
# str.format template: `{i}` receives the example's 0-based index. The odd
# "!!@@##@@!!" prefix is presumably chosen so it will not collide with task
# text, making examples easy to split apart later -- TODO confirm.
EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n"


def parse_args():
    """Parse this script's command-line options.

    Returns:
        An ``argparse.Namespace`` with the attributes ``output_base_path``,
        ``tasks``, ``description_path``, ``sets``, ``num_fewshot``, ``seed``
        and ``num_examples``.
    """
    parser = argparse.ArgumentParser(
        description="Write few-shot prompt contexts for lm_eval tasks to text files.",
    )
    parser.add_argument(
        '--output_base_path', required=True,
        help="Directory that receives one output file per task.",
    )
    parser.add_argument(
        '--tasks', default="all_tasks",
        help='Comma-separated task names, or "all_tasks" for every known task.',
    )
    parser.add_argument(
        '--description_path', default=None,
        help="Optional JSON file mapping task names to description strings.",
    )
    parser.add_argument(
        '--sets', type=str, default="val",
        help='Comma-separated splits to draw examples from, e.g. "val,test".',
    )
    parser.add_argument(
        '--num_fewshot', type=int, default=1,
        help="Number of few-shot examples placed in each context.",
    )
    parser.add_argument(
        '--seed', type=int, default=42,
        help="Seed for the random number generators.",
    )
    parser.add_argument(
        '--num_examples', type=int, default=1,
        help="Number of examples to write per task; a value <= 0 writes all.",
    )
    return parser.parse_args()


def _split_docs(task, split):
    """Return the documents of one dataset split of ``task``.

    Args:
        task: an lm_eval task object. Only its ``has_training_docs`` /
            ``training_docs``, ``has_validation_docs`` / ``validation_docs`` and
            ``has_test_docs`` / ``test_docs`` methods are used here.
        split: a split name from ``--sets``: ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        Whatever the matching ``*_docs()`` method returns, or ``None`` when the
        task reports that it has no documents for that split.

    Raises:
        ValueError: if ``split`` is not one of the recognised names.
    """
    if split == "train":
        return task.training_docs() if task.has_training_docs() else None
    if split == "val":
        return task.validation_docs() if task.has_validation_docs() else None
    if split == "test":
        return task.test_docs() if task.has_test_docs() else None
    raise ValueError(
        "Unknown split {!r} in --sets; expected train, val or test".format(split)
    )


def main():
    """Write few-shot prompt contexts for the selected tasks to text files.

    For each task, the documents of the splits listed in ``--sets`` are
    concatenated in the order given. Each of the first ``--num_examples``
    documents (every document when ``--num_examples`` <= 0) is written to
    ``<output_base_path>/<task_name>`` as ``EXAMPLE_DIVIDER`` followed by the
    task's few-shot context.

    A requested split that a task does not provide is skipped with a warning on
    stderr. A task that provides none of the requested splits gets no output
    file.

    Raises:
        ValueError: when a task is processed and ``--sets`` names an unknown
            split.
    """
    args = parse_args()
    # Also seed NumPy's global RNG. Presumably some task code samples through
    # np.random rather than the ``rnd`` passed below -- TODO confirm.
    np.random.seed(args.seed)

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)

    description_dict = {}
    if args.description_path:
        with open(args.description_path, "r", encoding="utf-8") as f:
            description_dict = json.load(f)

    split_names = args.sets.split(",")

    os.makedirs(args.output_base_path, exist_ok=True)
    for task_name, task in task_dict.items():
        # A fresh, identically seeded RNG per task keeps each task's output
        # independent of which tasks were processed before it.
        rnd = random.Random(args.seed)

        iters = []
        for split in split_names:
            split_docs = _split_docs(task, split)
            if split_docs is None:
                # Skip explicitly rather than re-using stale docs left over
                # from an earlier split or task.
                print(
                    "Warning: task {!r} has no {!r} docs; skipping that split.".format(
                        task_name, split
                    ),
                    file=sys.stderr,
                )
                continue
            iters.append(split_docs)

        if not iters:
            print(
                "Warning: task {!r} has none of the requested splits {!r}; "
                "no output written.".format(task_name, args.sets),
                file=sys.stderr,
            )
            continue

        docs = join_iters(iters)
        if args.num_examples > 0:
            docs = itertools.islice(docs, args.num_examples)

        description = description_dict.get(task_name, "")

        out_path = os.path.join(args.output_base_path, task_name)
        with open(out_path, "w", encoding="utf-8") as f:
            for i, doc in enumerate(docs):
                f.write(EXAMPLE_DIVIDER.format(i=i))
                ctx = task.fewshot_context(
                    doc=doc,
                    num_fewshot=args.num_fewshot,
                    rnd=rnd,
                    description=description,
                )
                f.write(ctx + "\n")


if __name__ == "__main__":
    # Entry point when run as a script; importing the module runs nothing.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())