main.py 1.17 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Leo Gao's avatar
Leo Gao committed
3

Jason Phang's avatar
Jason Phang committed
4
5
import models
import tasks
Leo Gao's avatar
Leo Gao committed
6

Jason Phang's avatar
Jason Phang committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def parse_args():
    """Read the evaluation options from the command line.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model`` (required
        string), ``model_args`` (string, default ``""``), ``tasks`` (string,
        default ``"all_tasks"``), ``provide_description`` (bool flag) and
        ``new_fewshot`` (bool flag).
    """
    cli = argparse.ArgumentParser()
    # The model name has no default; argparse exits if it is missing.
    cli.add_argument('--model', required=True)
    # Free-form string options, registered in their original order.
    for flag, fallback in (('--model_args', ""), ('--tasks', "all_tasks")):
        cli.add_argument(flag, default=fallback)
    # On/off switches that default to False.
    for flag in ('--provide_description', '--new_fewshot'):
        cli.add_argument(flag, action="store_true")
    parsed = cli.parse_args()
    return parsed


def main():
    """Evaluate a model on the requested tasks and print the results as JSON.

    Parses the command line with ``parse_args``, builds the model via
    ``models.get_model(...).create_from_arg_string(...)``, instantiates every
    requested task, and calls ``evaluate`` on each task that reports
    validation docs. Tasks without validation docs are skipped. The collected
    results are printed to stdout as an indented JSON object keyed by task
    name.
    """
    args = parse_args()
    # NOTE(review): `model` is constructed but never handed to
    # `task.evaluate` below. Evaluating tasks without the model looks
    # unintended; confirm the `evaluate` signature (it may expect an `lm`
    # argument) and pass the model through.
    model = models.get_model(args.model).create_from_arg_string(args.model_args)

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        # Tolerate spaces and stray commas, e.g. "a, b," -> ["a", "b"].
        task_names = [
            name.strip() for name in args.tasks.split(",") if name.strip()
        ]

    task_dict = {
        task_name: tasks.get_task(task_name)()
        for task_name in task_names
    }

    results = {}
    # Iterate over (name, task) pairs. Iterating the dict directly yields
    # only the string keys, which cannot be unpacked into two names.
    for task_name, task in task_dict.items():
        if not task.has_validation_docs():
            continue
        result = task.evaluate(
            docs=task.validation_docs(),
            provide_description=args.provide_description,
            # NOTE(review): `new_fewshot` is a store_true flag, so this
            # passes False/True (i.e. 0/1) as the few-shot count. Confirm
            # whether an integer option was intended.
            num_fewshot=args.new_fewshot,
        )
        results[task_name] = result
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()