evaluator.py

import collections
import itertools
import random


def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
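    """Evaluate each task in `task_dict` that has validation or test docs on `lm`.

    `lm` must expose one method per request type (looked up via `getattr(lm, req.type)`).
    `provide_description` and `num_fewshot` are passed through to each task's `fewshot_context`,
    and at most `limit` documents are evaluated per task.

    Returns a dict mapping task name -> {metric name: aggregated value}.
    """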
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    task_dict_items = [(name, task) for name, task in task_dict.items()
                       if task.has_validation_docs() or task.has_test_docs()]

    results = collections.defaultdict(dict)

    # request type -> flat list of Requests of that type, across all tasks
    requests = collections.defaultdict(list)
    # request type -> parallel list of (i, task_name, doc, doc_id) recording where each request came from
    requests_origin = collections.defaultdict(list)

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with more memory,
    # we can always modify this plumbing to support that, but I didn't want to include it just yet because overengineering is bad
    # (or we could make it write the requests to disk and then read them back out again - probably using an sqlite db because of all the moving parts we have).

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable

    # (task_name, doc_id) -> doc, so results can be matched back to their documents later
    docs = {}

    # get lists of each type of request
    for task_name, task in task_dict_items:
        # default to validation docs, fall back to test docs if validation is unavailable
        # TODO: the val-fallback-to-test system isn't final, we should revisit it at some point
        if task.has_validation_docs():
            task_doc_func = task.validation_docs
        elif task.has_test_docs():
            task_doc_func = task.test_docs

        # deterministically shuffle the docs and keep only the first `limit`, because sometimes docs come in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            docs[(task_name, doc_id)] = doc

            ctx = task.fewshot_context(
                doc=doc,
                provide_description=provide_description,
                num_fewshot=num_fewshot,
            )

            reqs = task.construct_requests(doc, ctx)
            # construct_requests may return a single Request or a list/tuple of Requests; normalize to a list
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.type].append((i, task_name, doc, doc_id))

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        # only in index. We could implement some kind of caching, but that would be more of a bandaid
        # solution. We could also implement some kind of autogrouping here; they should end up next to each other.

        # call the LM method named after the request type once, with the args of every request of that type
        resps = getattr(lm, reqtype)([req.args for req in reqs])

        # a Request may carry an index that selects one element of a multi-part response; None means use the whole response
        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
    
    # (task_name, metric) -> list of per-doc metric values
    vals = collections.defaultdict(list)

    # unpack results, sort them back into order, and hand them back to each Task to compute metrics
    for (task_name, doc_id), responses in process_res_queue.items():
        responses.sort(key=lambda x: x[0])
        responses = [x[1] for x in responses]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, responses)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)
    
    # aggregate results: task.aggregation() maps each metric name to a function that reduces the per-doc values to a single number
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)
    
    return results
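

# A minimal usage sketch (the `lm` and task objects here are hypothetical; they come from elsewhere in the codebase):
#
#     results = evaluate(
#         lm=my_lm,                         # object exposing one method per request type
#         task_dict={"my_task": my_task},   # Task objects with validation or test docs
#         provide_description=False,
#         num_fewshot=0,
#         limit=100,
#     )
#     # results -> {"my_task": {"some_metric": aggregated_value}}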