# test_description_dict.py
import json
import argparse
import lm_eval.tasks
import lm_eval.models
from lm_eval.evaluator import evaluate


8
9
10
11
12
13
14
def test_cli_description_dict_path():
    """Evaluate the dummy model on two tasks with an optional description dict.

    Options are read from the command line (all optional):

    * ``--description_dict_path``: path to a JSON file; its parsed content is
      passed to ``evaluate`` as the description dict. When omitted, an empty
      dict is passed instead.
    * ``--num_fewshot``: integer forwarded to ``evaluate`` (default ``0``).
    * ``--limit``: integer forwarded to ``evaluate`` (default ``None``).

    The function only checks that ``evaluate`` runs to completion; it makes
    no assertion about the returned results.
    """
    def parse_args():
        """Parse the options this test understands, ignoring everything else."""
        parser = argparse.ArgumentParser()
        parser.add_argument('--description_dict_path', default=None)
        parser.add_argument('--num_fewshot', type=int, default=0)
        parser.add_argument('--limit', type=int, default=None)
        # parse_known_args() instead of parse_args(): when this function is
        # collected by a test runner, sys.argv holds the runner's own
        # arguments, and parse_args() would abort with SystemExit on them.
        known_args, _unknown = parser.parse_known_args()
        return known_args

    args = parse_args()

    task_names = ['hellaswag', 'copa']
    task_dict = lm_eval.tasks.get_task_dict(task_names)
    lm = lm_eval.models.get_model('dummy')()

    description_dict = {}
    if args.description_dict_path:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(args.description_dict_path, 'r', encoding='utf-8') as f:
            description_dict = json.load(f)

    num_fewshot = args.num_fewshot
    # Arguments stay positional because evaluate()'s parameter names are not
    # visible here. NOTE(review): the literal False is presumably the
    # "provide description" flag -- confirm against lm_eval.evaluator.
    evaluate(
        lm,
        task_dict,
        False,
        num_fewshot,
        args.limit,
        description_dict
    )


# Allow running this file directly as a script, so the command-line options
# handled inside the test can be supplied on the command line.
if __name__ == '__main__':
    test_cli_description_dict_path()