import collections
import logging
from functools import partial

import evaluate

from lm_eval.api.model import LM


eval_logger = logging.getLogger("lm-eval")

MODEL_REGISTRY = {}


def register_model(*names):
    # Accepts one or more aliases; the decorator receives them
    # as a tuple of strings.

    def decorate(cls):
        assert issubclass(
            cls, LM
        ), f"Model ({cls.__name__}) must extend LM class"

        for name in names:

            assert (
                name not in MODEL_REGISTRY
            ), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."

            MODEL_REGISTRY[name] = cls
        return cls

    return decorate


def get_model(model_name):
haileyschoelkopf's avatar
haileyschoelkopf committed
36
37
38
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
39
40
41
        raise ValueError(
            f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
        )
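

# Usage sketch (hypothetical `MyCustomLM` subclass; names are illustrative,
# not part of this module):
#
#     @register_model("my-custom-lm", "my-custom-lm-alias")
#     class MyCustomLM(LM):
#         ...
#
# `get_model("my-custom-lm")` then returns the MyCustomLM class; an unknown
# name raises a ValueError listing all registered model names.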


TASK_REGISTRY = {}
GROUP_REGISTRY = {}
ALL_TASKS = set()
func2task_index = {}


def register_task(name):
    def decorate(fn):
        assert (
            name not in TASK_REGISTRY
        ), f"task named '{name}' conflicts with existing registered task!"

        TASK_REGISTRY[name] = fn
        ALL_TASKS.add(name)
        func2task_index[fn.__name__] = name
        return fn

    return decorate
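

# Usage sketch (hypothetical task function; names are illustrative):
#
#     @register_task("my_task")
#     def my_task():
#         ...
#
# This records the task under TASK_REGISTRY["my_task"], adds the name to
# ALL_TASKS, and maps the function's name to the task name in func2task_index.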


def register_group(name):
    def decorate(fn):
        func_name = func2task_index[fn.__name__]
        if name in GROUP_REGISTRY:
            GROUP_REGISTRY[name].append(func_name)
        else:
            GROUP_REGISTRY[name] = [func_name]
            ALL_TASKS.add(name)
        return fn

    return decorate
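

# Usage sketch: since decorators apply bottom-up, @register_task must sit
# below @register_group so that func2task_index is populated before
# register_group looks the task up (hypothetical names, for illustration):
#
#     @register_group("my_group")
#     @register_task("my_task")
#     def my_task():
#         ...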


METRIC_REGISTRY = collections.defaultdict(dict)
AGGREGATION_REGISTRY = collections.defaultdict(dict)

DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [],
    "loglikelihood_rolling": [],
    "multiple_choice": [],
    "generate_until": [],
}


def register_metric(
    metric=None,
    higher_is_better=None,
    output_type=None,
    aggregation=None,
):
    # TODO: do we want to enforce a certain interface to registered metrics?
    def decorate(fn):
        if isinstance(metric, str):
            metric_list = [metric]
        elif isinstance(metric, list):
            metric_list = metric
        else:
            raise ValueError(
                "metric must be a string or a list of strings!"
            )

        for _metric in metric_list:
            METRIC_REGISTRY[_metric]["function"] = fn

            if aggregation is not None:
                METRIC_REGISTRY[_metric]["aggregation"] = aggregation

            if higher_is_better is not None:
                METRIC_REGISTRY[_metric]["higher_is_better"] = higher_is_better

            if output_type is not None:
                if isinstance(output_type, str):
                    output_type_list = [output_type]
                elif isinstance(output_type, list):
                    output_type_list = output_type
                else:
                    raise ValueError(
                        "output_type must be a string or a list of strings!"
                    )

                for _output_type in output_type_list:
                    DEFAULT_METRIC_REGISTRY[_output_type].append(_metric)

        return fn

    return decorate
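

# Usage sketch (hypothetical exact-match metric; names are illustrative):
#
#     @register_metric(
#         metric="my_exact_match",
#         higher_is_better=True,
#         output_type="generate_until",
#         aggregation="mean",
#     )
#     def my_exact_match_fn(items):
#         ...
#
# This stores the function and its options under
# METRIC_REGISTRY["my_exact_match"] and appends the metric name to
# DEFAULT_METRIC_REGISTRY["generate_until"].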


def get_metric(name):
    # returns the full registry entry ({"function": ..., "aggregation": ...,
    # "higher_is_better": ...}), not just the metric function itself
    if name in METRIC_REGISTRY:
        return METRIC_REGISTRY[name]
    else:
        eval_logger.error(f"Could not find registered metric '{name}' in lm-eval")


def get_evaluate(name, **kwargs):
    try:

        class HFEvaluateAdaptor:
            def __init__(self, name, **kwargs):
                self.name = name
                metric_object = evaluate.load(name)
                self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)

            def __call__(self, items):
                # each item is a (reference, prediction) pair
                refs, preds = zip(*items)

                return self.hf_evaluate_fn(references=refs, predictions=preds)[
                    self.name
                ]

        return HFEvaluateAdaptor(name, **kwargs)
    except Exception:
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
        )
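

# Usage sketch, assuming the Hugging Face `evaluate` library is installed and
# the "exact_match" metric exists on the Hub:
#
#     exact_match = get_evaluate("exact_match")
#     score = exact_match([("hello", "hello"), ("foo", "bar")])  # ~0.5
#
# The adaptor unzips (reference, prediction) pairs, forwards them to the
# metric's compute(), and returns the value stored under the metric's name.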


def register_aggregation(name):
    def decorate(fn):
        assert (
            name not in AGGREGATION_REGISTRY
        ), f"aggregation named '{name}' conflicts with existing registered aggregation!"

        AGGREGATION_REGISTRY[name] = fn
        return fn

    return decorate


def get_aggregation(name):
    try:
        return AGGREGATION_REGISTRY[name]
    except KeyError:
        eval_logger.warning(f"{name} not a registered aggregation metric!")
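

# Usage sketch (hypothetical aggregation; names are illustrative):
#
#     @register_aggregation("my_mean")
#     def my_mean(items):
#         return sum(items) / len(items)
#
# `get_aggregation("my_mean")` then returns the function; an unknown name
# only logs a warning and implicitly returns None.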