import os
import evaluate
from lm_eval.api.model import LM
4
5

import logging
eval_logger = logging.getLogger("lm-eval")

# Maps each registered alias to its LM subclass.
MODEL_REGISTRY = {}


def register_model(*names):
    """Decorator: register an ``LM`` subclass under every alias in *names*.

    Either pass a list or a single alias; the function receives them as a
    tuple of strings. The decorated class is returned unchanged.
    """

    def decorate(model_cls):
        for alias in names:
            assert issubclass(
                model_cls, LM
            ), f"Model '{alias}' ({model_cls.__name__}) must extend LM class"
            assert (
                alias not in MODEL_REGISTRY
            ), f"Model named '{alias}' conflicts with existing model! Please register with a non-conflicting alias instead."
            MODEL_REGISTRY[alias] = model_cls
        return model_cls

    return decorate


def get_model(model_name):
    """Return the LM subclass registered under ``model_name``.

    Raises:
        ValueError: if no model was registered under that name.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        supported = ", ".join(MODEL_REGISTRY.keys())
        raise ValueError(
            f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {supported}"
        )


# Task bookkeeping: task name -> task object, group name -> member task
# names, the set of every known task/group name, and the mapping from a
# decorated object's __name__ back to its registered task name.
TASK_REGISTRY = {}
GROUP_REGISTRY = {}
ALL_TASKS = set()
func2task_index = {}


def register_task(name):
    """Decorator: register the decorated object as the task called ``name``.

    The task name must be unique; the decorated object is returned unchanged.
    """

    def decorate(task_obj):
        assert (
            name not in TASK_REGISTRY
        ), f"task named '{name}' conflicts with existing registered task!"
        ALL_TASKS.add(name)
        func2task_index[task_obj.__name__] = name
        TASK_REGISTRY[name] = task_obj
        return task_obj

    return decorate


def register_group(name):
    """Decorator: add an already-registered task to the group ``name``.

    Looks up the task name via ``func2task_index``; creates the group (and
    records it in ``ALL_TASKS``) on first use, otherwise appends to the
    existing member list. The decorated object is returned unchanged.
    """

    def decorate(fn):
        task_name = func2task_index[fn.__name__]
        members = GROUP_REGISTRY.get(name)
        if members is None:
            GROUP_REGISTRY[name] = [task_name]
            ALL_TASKS.add(name)
        else:
            members.append(task_name)
        return fn

    return decorate


# Metric machinery registries.
# ``register_metric`` populates METRIC_REGISTRY, HIGHER_IS_BETTER_REGISTRY
# and DEFAULT_METRIC_REGISTRY; ``register_aggregation`` populates
# AGGREGATION_REGISTRY.
# NOTE(review): OUTPUT_TYPE_REGISTRY and METRIC_AGGREGATION_REGISTRY are
# never written to anywhere in this file -- confirm they are populated
# elsewhere or are leftovers.
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {}

# Default metric names attached to each request/output type; filled in via
# the ``output_type`` argument of ``register_metric``.
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [],
    "loglikelihood_rolling": [],
    "multiple_choice": [],
    "generate_until": [],
}


def register_metric(
    metric,
    higher_is_better=None,
    output_type=None,
    aggregation=None,
):
    """Decorator factory: register a metric under the name ``metric``.

    Args:
        metric: registry key under which the metric is stored.
        higher_is_better: if not None, recorded in HIGHER_IS_BETTER_REGISTRY.
        output_type: if not None, ``metric`` is appended to the default-metric
            list for that output type (must be an existing key of
            DEFAULT_METRIC_REGISTRY, otherwise this raises KeyError).
        aggregation: forwarded to the decorated callable when it is invoked
            at registration time (see note below).
    """

    # TODO: do we want to enforce a certain interface to registered metrics?
    def decorate(fn):
        # NOTE(review): ``fn`` is treated as a factory here -- it is *called*
        # with ``aggregation`` and the result is what gets stored, while
        # ``fn`` itself is handed back to the caller. Confirm registered
        # metrics are expected to be factories.
        METRIC_REGISTRY[metric] = fn(aggregation=aggregation)

        if higher_is_better is not None:
            HIGHER_IS_BETTER_REGISTRY[metric] = higher_is_better
        if output_type is not None:
            DEFAULT_METRIC_REGISTRY[output_type].append(metric)

        # NOTE(review): unlike register_task/register_aggregation there is no
        # duplicate-name check here, and ``aggregation`` is never recorded in
        # METRIC_AGGREGATION_REGISTRY -- verify both are intentional.
        return fn

    return decorate


def get_metric(name, hf_evaluate_metric=False):
    """Look up a metric callable by name.

    Checks the local METRIC_REGISTRY first unless ``hf_evaluate_metric`` is
    set, then falls back to the HF Evaluate library; logs an error and
    returns None when neither source knows the metric.
    """
    if not hf_evaluate_metric and name in METRIC_REGISTRY:
        return METRIC_REGISTRY[name]
    if not hf_evaluate_metric:
        eval_logger.warning(
            f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
        )

    try:
        return evaluate.load(name).compute
    except Exception:
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
        )


def register_aggregation(name):
    """Decorator: record the decorated callable as the aggregation ``name``.

    The name must not already be registered; the callable is returned
    unchanged.
    """

    def decorate(agg_fn):
        assert (
            name not in AGGREGATION_REGISTRY
        ), f"aggregation named '{name}' conflicts with existing registered aggregation!"
        AGGREGATION_REGISTRY[name] = agg_fn
        return agg_fn

    return decorate


def get_aggregation(name):
    """Return the aggregation callable registered as ``name``.

    Logs a warning and returns None when the name is unknown.
    """
    if name in AGGREGATION_REGISTRY:
        return AGGREGATION_REGISTRY[name]
    eval_logger.warning(
        "{} not a registered aggregation metric!".format(name),
    )


def get_metric_aggregation(name):
    """Return the default aggregation assigned to the metric ``name``.

    Logs a warning and returns None when no default aggregation is recorded.
    """
    if name in METRIC_AGGREGATION_REGISTRY:
        return METRIC_AGGREGATION_REGISTRY[name]
    eval_logger.warning(
        "{} metric is not assigned a default aggregation!".format(name),
    )


def is_higher_better(metric_name):
    """Return the registered higher-is-better flag for ``metric_name``.

    Logs a warning and returns None when the metric has no recorded flag.
    """
    if metric_name in HIGHER_IS_BETTER_REGISTRY:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    eval_logger.warning(
        f"higher_is_better not specified for metric '{metric_name}'!"
    )