# registry.py
import os
import evaluate
from lm_eval.api.model import LM
import logging


# Shared "lm-eval" logger used by the registry helpers in this module.
eval_logger = logging.getLogger("lm-eval")


# Model alias -> LM subclass; populated by ``register_model`` and read by ``get_model``.
MODEL_REGISTRY = {}


def register_model(*names):
    """Class decorator that registers an ``LM`` subclass under one or more aliases.

    Usage: ``@register_model("name")`` or ``@register_model("alias1", "alias2")``;
    the aliases arrive here as a tuple of strings.

    Args:
        *names: Alias strings; each becomes a key in ``MODEL_REGISTRY``.

    Returns:
        A decorator that records the class under every alias and returns the
        class unchanged.

    Raises:
        AssertionError: if the class does not extend ``LM``, or an alias is
            already registered (or repeated within ``names``). Every alias is
            validated before any is stored, so a failure leaves
            ``MODEL_REGISTRY`` untouched.
    """

    def decorate(cls):
        # The subclass check does not depend on the alias, so run it once.
        # The message names the first alias, exactly as the per-alias check did.
        if names:
            assert issubclass(
                cls, LM
            ), f"Model '{names[0]}' ({cls.__name__}) must extend LM class"

        # Validate all aliases before mutating the registry so a conflict on a
        # later alias cannot leave earlier aliases half-registered.
        seen = set()
        for name in names:
            assert (
                name not in MODEL_REGISTRY and name not in seen
            ), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
            seen.add(name)

        for name in names:
            MODEL_REGISTRY[name] = cls
        return cls

    return decorate


def get_model(model_name):
    """Return the ``LM`` subclass registered under ``model_name``.

    Raises:
        ValueError: if no model is registered under that name; the message
            lists every registered alias.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        # ``from None`` drops the internal KeyError from the traceback so the
        # user sees only the actionable ValueError.
        raise ValueError(
            f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
        ) from None


# Task name -> the object decorated with ``register_task``.
TASK_REGISTRY = {}
# Group name -> list of task names added via ``register_group``.
GROUP_REGISTRY = {}
# Every registered task name and group name.
ALL_TASKS = set()
# ``fn.__name__`` -> task name, so ``register_group`` can map a decorated
# object back to the task name it was registered under.
func2task_index = {}


def register_task(name):
    """Decorator that registers ``fn`` as the task called ``name``.

    Stores ``fn`` in ``TASK_REGISTRY``, adds ``name`` to ``ALL_TASKS`` and
    indexes ``fn.__name__`` so a later ``register_group`` can find the task.

    Raises:
        AssertionError: if a task called ``name`` is already registered.
    """

    def decorate(fn):
        assert (
            name not in TASK_REGISTRY
        ), f"task named '{name}' conflicts with existing registered task!"

        TASK_REGISTRY[name] = fn
        ALL_TASKS.add(name)
        func2task_index[fn.__name__] = name
        return fn

    return decorate


def register_group(name):
    """Decorator that adds an already-registered task to the group ``name``.

    ``fn`` must have been through ``register_task`` first. Decorators apply
    bottom-up, so ``@register_task`` goes below ``@register_group``.

    Raises:
        KeyError: if ``fn`` has not been registered as a task.
    """

    def decorate(fn):
        try:
            task_name = func2task_index[fn.__name__]
        except KeyError:
            raise KeyError(
                f"Cannot add '{fn.__name__}' to group '{name}': it is not a registered task. "
                "Apply @register_task first (list it below @register_group)."
            ) from None

        members = GROUP_REGISTRY.setdefault(name, [])
        # Re-decorating must not list the same task twice in one group.
        if task_name not in members:
            members.append(task_name)
        ALL_TASKS.add(name)
        return fn

    return decorate


# Metric name -> metric function.
METRIC_FUNCTION_REGISTRY = {}
# Metric name -> higher_is_better flag (only set when one was provided).
HIGHER_IS_BETTER_REGISTRY = {}

# Output type -> names of metrics registered as defaults for that output type.
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [],
    "loglikelihood_rolling": [],
    "multiple_choice": [],
    "generate_until": [],
}


def _as_name_list(value, arg_name):
    """Normalize a single string, or a list/tuple of strings, to a new list."""
    if isinstance(value, str):
        return [value]
    if isinstance(value, (list, tuple)):
        return list(value)
    raise TypeError(
        f"register_metric: '{arg_name}' must be a str or a list of str, got {type(value).__name__}"
    )


def register_metric(
    metric=None,
    higher_is_better=None,
    output_type=None,
    # aggregation=None,
):
    """Decorator that registers ``fn`` under one or more metric names.

    Args:
        metric: Metric name, or list/tuple of names, to register ``fn`` under.
            Required despite the ``None`` default.
        higher_is_better: When not None, stored in ``HIGHER_IS_BETTER_REGISTRY``
            for every name.
        output_type: Optional output type (or list/tuple of them). Each metric
            name is appended once to ``DEFAULT_METRIC_REGISTRY[output_type]``.

    Returns:
        A decorator that records ``fn`` and returns it unchanged.

    Raises:
        TypeError: if ``metric``, or a non-None ``output_type``, is not a str
            or list/tuple of str.
        KeyError: if an ``output_type`` is not a key of
            ``DEFAULT_METRIC_REGISTRY``. This is checked before anything is
            registered.
    """
    # TODO: do we want to enforce a certain interface to registered metrics?
    # TODO: aggregation registration (AGGREGATION_REGISTRY) is currently disabled.

    # Normalize and validate once, up front. Bad arguments then fail with a
    # clear error instead of an UnboundLocalError, and never leave the
    # registries partially updated.
    metric_list = _as_name_list(metric, "metric")
    output_type_list = (
        [] if output_type is None else _as_name_list(output_type, "output_type")
    )
    for _output_type in output_type_list:
        if _output_type not in DEFAULT_METRIC_REGISTRY:
            raise KeyError(
                f"Unknown output_type '{_output_type}'; expected one of {list(DEFAULT_METRIC_REGISTRY)}"
            )

    def decorate(fn):
        for _metric in metric_list:
            METRIC_FUNCTION_REGISTRY[_metric] = fn

            if higher_is_better is not None:
                HIGHER_IS_BETTER_REGISTRY[_metric] = higher_is_better

            for _output_type in output_type_list:
                defaults = DEFAULT_METRIC_REGISTRY[_output_type]
                # Re-registering a metric must not list it twice as a default.
                if _metric not in defaults:
                    defaults.append(_metric)

        return fn

    return decorate


def get_metric(name, hf_evaluate_metric=False):
    """Look up a metric callable by name.

    Unless ``hf_evaluate_metric`` is True, ``METRIC_FUNCTION_REGISTRY`` is
    checked first. If the name is missing there, or if ``hf_evaluate_metric``
    is True, falls back to ``evaluate.load(name)`` and returns its ``compute``
    method.

    Returns:
        The metric callable, or None when it cannot be found in either place.
        That failure is logged, not raised.
    """
    if not hf_evaluate_metric:
        if name in METRIC_FUNCTION_REGISTRY:
            return METRIC_FUNCTION_REGISTRY[name]
        eval_logger.warning(
            f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
        )

    try:
        metric_object = evaluate.load(name)
    except Exception as err:
        # Best-effort lookup: callers receive None. Keep the underlying cause
        # in the log instead of discarding it.
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric ({type(err).__name__}: {err})",
        )
        return None
    return metric_object.compute


# def register_aggregation(name):
#     def decorate(fn):
#         assert (
#             name not in AGGREGATION_REGISTRY
#         ), f"aggregation named '{name}' conflicts with existing registered aggregation!"

#         AGGREGATION_REGISTRY[name] = fn
#         return fn

#     return decorate


# def get_aggregation(name):
#     try:
#         return AGGREGATION_REGISTRY[name]
#     except KeyError:
#         eval_logger.warning(
#             "{} not a registered aggregation metric!".format(name),
#         )


# def get_metric_aggregation(name):
#     try:
#         return METRIC_AGGREGATION_REGISTRY[name]
#     except KeyError:
#         eval_logger.warning(
#             "{} metric is not assigned a default aggregation!".format(name),
#         )


def is_higher_better(metric_name):
    """Return the ``higher_is_better`` flag registered for ``metric_name``.

    Returns:
        The value stored in ``HIGHER_IS_BETTER_REGISTRY``, or None (after
        logging a warning) when no flag was registered for the metric.
    """
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
        eval_logger.warning(
            f"higher_is_better not specified for metric '{metric_name}'!"
        )
        return None