write_out.py 2.93 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import os
Michael Pieler's avatar
Michael Pieler committed
3
import random
4
5
6

import numpy as np

Michael Pieler's avatar
Michael Pieler committed
7
from lm_eval import tasks
Baber Abbasi's avatar
Baber Abbasi committed
8
from lm_eval.evaluator_utils import get_task_list
9
from lm_eval.tasks import TaskManager
10
11
from lm_eval.utils import eval_logger, join_iters

Jason Phang's avatar
Jason Phang committed
12
13
14
15
16
17

# Separator written before each example in the output files; ``{i}`` is
# replaced with the example's index.
EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n"


def parse_args():
    """Parse the command-line options for this script.

    Returns:
        argparse.Namespace with the attributes ``output_base_path``,
        ``tasks``, ``sets``, ``num_fewshot``, ``seed``, ``num_examples``,
        ``include_path`` and ``verbosity``.
    """
    parser = argparse.ArgumentParser(
        description="Write few-shot prompts for evaluation tasks to text files."
    )
    parser.add_argument(
        "--output_base_path",
        "--output_path",
        required=True,
        help="Directory that receives one output file per task (created if missing).",
    )
    parser.add_argument(
        "--tasks",
        default="all_tasks",
        help="Comma-separated task names, or 'all_tasks' for every task the TaskManager knows.",
    )
    parser.add_argument(
        "--sets",
        type=str,
        default="val",
        help="Comma-separated splits to draw examples from: train, val and/or test.",
    )  # example: val,test
    parser.add_argument(
        "--num_fewshot",
        type=int,
        default=1,
        help="Number of few-shot examples in each written prompt.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for NumPy's global random number generator.",
    )
    parser.add_argument(
        "--num_examples",
        type=int,
        default=1,
        help="Number of examples to write per task; 0 or less writes every example.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--verbosity",
        type=str,
        default="INFO",
        help="Logging verbosity level passed to the TaskManager (e.g. DEBUG, INFO).",
    )
    return parser.parse_args()


def _collect_docs(task, split_names):
    """Return the document iterables of ``task`` for each requested split.

    Recognised split names are ``train``, ``val`` and ``test``. A name that
    is not recognised, a split the task reports it does not have, and a split
    whose loader returns ``None`` all contribute nothing.
    """
    collected = []
    for split_name in split_names:
        docs = None
        if split_name == "train" and task.has_training_docs():
            docs = task.training_docs()
        elif split_name == "val" and task.has_validation_docs():
            docs = task.validation_docs()
        elif split_name == "test" and task.has_test_docs():
            docs = task.test_docs()
        if docs is not None:
            collected.append(docs)
    return collected


def main():
    """Write few-shot prompts for the selected tasks, one file per task.

    For every task, documents are drawn from the splits named in ``--sets``.
    Up to ``--num_examples`` of them (all of them when the value is 0 or
    negative) are written to ``<output_base_path>/<task name>``. Each example
    is ``EXAMPLE_DIVIDER`` followed by the prompt built by
    ``task.fewshot_context``.

    Raises:
        ValueError: if a task provides none of the splits named in ``--sets``.
    """
    args = parse_args()
    # Seed NumPy's global RNG so anything downstream drawing from it is
    # reproducible across runs.
    np.random.seed(args.seed)

    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    if args.tasks == "all_tasks":
        task_names = task_manager.all_tasks
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names, task_manager)

    os.makedirs(args.output_base_path, exist_ok=True)
    # The requested splits are the same for every task; parse them once.
    split_names = args.sets.split(",")

    for task in [x.task for x in get_task_list(task_dict)]:
        task_name = task.config.task

        iters = _collect_docs(task, split_names)
        if not iters:
            raise ValueError(
                f"Passed --sets '{args.sets}' but task '{task_name}' has no splits which match. Please specify a different --sets value."
            )

        docs = join_iters(iters)

        with open(
            os.path.join(args.output_base_path, task_name), "w", encoding="utf8"
        ) as f:
            # zip with range() caps the number of examples; a non-positive
            # --num_examples means "write every document".
            for i, doc in (
                zip(range(args.num_examples), docs)
                if args.num_examples > 0
                else enumerate(docs)
            ):
                f.write(EXAMPLE_DIVIDER.format(i=i))
                ctx = task.fewshot_context(
                    doc=doc,
                    num_fewshot=args.num_fewshot,
                )
                f.write(ctx + "\n")


if __name__ == "__main__":
    main()