registry.py 5.64 KB
Newer Older
Baber's avatar
cleanup  
Baber committed
1
2
from __future__ import annotations

3
import logging
Baber's avatar
cleanup  
Baber committed
4
from typing import TYPE_CHECKING, Any, Callable
5

6

Baber's avatar
Baber committed
7
8
if TYPE_CHECKING:
    from lm_eval.api.model import LM
lintangsutawika's avatar
lintangsutawika committed
9

Lintang Sutawika's avatar
Lintang Sutawika committed
10
eval_logger = logging.getLogger(__name__)

# Maps model alias (str) -> LM subclass; populated by @register_model.
MODEL_REGISTRY = {}

# Library-wide fallback settings used when a model/task config omits them.
DEFAULTS = {
    "model": {"max_length": 2048},
    "tasks": {"generate_until": {"max_gen_toks": 256}},
}
17
18
19


def register_model(*names):
    """Class decorator that adds an LM subclass to MODEL_REGISTRY.

    Pass a single alias or several; they arrive here as a tuple of strings,
    and the class is registered once per alias.
    """
    from lm_eval.api.model import LM

    def decorate(cls):
        for alias in names:
            assert issubclass(cls, LM), (
                f"Model '{alias}' ({cls.__name__}) must extend LM class"
            )
            assert alias not in MODEL_REGISTRY, (
                f"Model named '{alias}' conflicts with existing model! Please register with a non-conflicting alias instead."
            )
            MODEL_REGISTRY[alias] = cls
        return cls

    return decorate


Baber's avatar
cleanup  
Baber committed
41
def get_model(model_name: str) -> type[LM]:
    """Look up a registered model class by alias.

    Raises:
        KeyError: if ``model_name`` was never registered; the message lists
            every currently registered alias.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError as missing:
        listing = ", ".join(MODEL_REGISTRY.keys())
        raise KeyError(
            f"Model '{model_name}' not found. Available models: {listing}"
        ) from missing
49
50
51
52


TASK_REGISTRY = {}
GROUP_REGISTRY = {}
53
ALL_TASKS = set()
54
55
56
func2task_index = {}


Baber's avatar
Baber committed
57
def register_task(name: str):
    """Decorator that registers the decorated task function under ``name``."""

    def decorate(task_fn):
        assert name not in TASK_REGISTRY, (
            f"task named '{name}' conflicts with existing registered task!"
        )

        ALL_TASKS.add(name)
        TASK_REGISTRY[name] = task_fn
        func2task_index[task_fn.__name__] = name
        return task_fn

    return decorate


def register_group(name):
    """Decorator adding an already-registered task to the group ``name``.

    The decorated function must have been through @register_task first so
    that func2task_index can resolve its registered task name.
    """

    def decorate(task_fn):
        registered_name = func2task_index[task_fn.__name__]
        if name not in GROUP_REGISTRY:
            GROUP_REGISTRY[name] = [registered_name]
            ALL_TASKS.add(name)
        else:
            GROUP_REGISTRY[name].append(registered_name)
        return task_fn

    return decorate


# Registries below are all keyed by name (str) and filled in by the
# register_* decorators in this module.
OUTPUT_TYPE_REGISTRY = {}
# metric name -> metric callable (populated by @register_metric).
METRIC_REGISTRY = {}
# metric name -> aggregation callable used to combine per-sample values.
METRIC_AGGREGATION_REGISTRY = {}
# aggregation name -> aggregation callable (populated by @register_aggregation).
# NOTE(review): the value annotation looks odd for an aggregation callable
# (it describes a zero-arg factory) — confirm against registered aggregations.
AGGREGATION_REGISTRY: dict[str, Callable[[], dict[str, Callable]]] = {}
# metric name -> bool; whether a larger metric value means better performance.
HIGHER_IS_BETTER_REGISTRY = {}
# filter name -> filter class (populated by @register_filter).
FILTER_REGISTRY = {}

# Default metric names attached to each task output type when a task config
# does not specify its own metrics.
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
        "acc",
    ],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": ["acc", "acc_norm"],
    "generate_until": ["exact_match"],
}


def register_metric(**args):
    """Decorator registering a metric function together with its metadata.

    Keyword Args:
        metric (str): required; the metric's registry name.
        higher_is_better (bool): optional; whether larger values are better.
        aggregation (str): optional; name of a registered aggregation used
            to combine per-sample metric values.

    Raises:
        AssertionError: if ``metric`` is missing, or the metric name is
            already present in a registry being written to.
    """
    # TODO: do we want to enforce a certain interface to registered metrics?
    def decorate(fn):
        assert "metric" in args
        name = args["metric"]

        for key, registry in [
            ("metric", METRIC_REGISTRY),
            ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
            ("aggregation", METRIC_AGGREGATION_REGISTRY),
        ]:
            if key in args:
                value = args[key]
                # All three registries are keyed by the metric *name*, so a
                # conflict means ``name`` is already a key. (The previous
                # check used ``value``, which could never fire for
                # higher_is_better — a bool is never a key — and could fire
                # spuriously for aggregation when a metric name happened to
                # equal an aggregation name.)
                assert name not in registry, (
                    f"{key} named '{name}' conflicts with existing registered {key}!"
                )

                if key == "metric":
                    registry[name] = fn
                elif key == "aggregation":
                    # Resolve the aggregation name to its callable now.
                    registry[name] = AGGREGATION_REGISTRY[value]
                else:
                    registry[name] = value

        return fn

    return decorate


Baber's avatar
cleanup  
Baber committed
131
def get_metric(name: str, hf_evaluate_metric=False) -> Callable[..., Any] | None:
    """Fetch a metric callable by name.

    Unless ``hf_evaluate_metric`` is True, the local METRIC_REGISTRY is
    consulted first; otherwise (or on a miss, after a warning) the metric is
    loaded from the HF Evaluate library. Returns None when neither source
    has it.
    """
    if not hf_evaluate_metric:
        if name in METRIC_REGISTRY:
            return METRIC_REGISTRY[name]
        eval_logger.warning(
            f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
        )

    try:
        import evaluate as hf_evaluate

        return hf_evaluate.load(name).compute
    except Exception:
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
        )


Baber Abbasi's avatar
Baber Abbasi committed
151
def register_aggregation(name: str):
    """Decorator registering an aggregation callable under ``name``."""

    def decorate(agg_fn):
        # Aggregation names must be unique across the whole registry.
        assert name not in AGGREGATION_REGISTRY, (
            f"aggregation named '{name}' conflicts with existing registered aggregation!"
        )

        AGGREGATION_REGISTRY[name] = agg_fn
        return agg_fn

    return decorate


Baber's avatar
cleanup  
Baber committed
163
def get_aggregation(name: str) -> Callable[..., Any] | None:
    """Return the aggregation registered under ``name``.

    Logs a warning and returns None when the name is unknown.
    """
    if name in AGGREGATION_REGISTRY:
        return AGGREGATION_REGISTRY[name]
    eval_logger.warning(f"{name} not a registered aggregation metric!")
haileyschoelkopf's avatar
haileyschoelkopf committed
168
169


170
def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable[..., Any]]]:
    """Return the default aggregation assigned to metric ``name``.

    Falls back to the registered "mean" aggregation (after a warning) when
    the metric has no assigned aggregation.
    """
    if name in METRIC_AGGREGATION_REGISTRY:
        return METRIC_AGGREGATION_REGISTRY[name]
    eval_logger.warning(
        f"{name} metric is not assigned a default aggregation!. Using default aggregation mean"
    )
    return AGGREGATION_REGISTRY["mean"]
178
179


180
def is_higher_better(metric_name: str) -> bool:
    """Whether a larger value of ``metric_name`` means better performance.

    Defaults to True (after a warning) for metrics that never declared it.
    """
    if metric_name in HIGHER_IS_BETTER_REGISTRY:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    eval_logger.warning(
        f"higher_is_better not specified for metric '{metric_name}'!. Will default to True."
    )
    return True
188
189


Baber's avatar
cleanup  
Baber committed
190
def register_filter(name: str):
    """Decorator registering a filter class under ``name``.

    Re-registration is permitted: it logs an info message and overwrites the
    previous entry.
    """

    def decorate(filter_cls):
        if name in FILTER_REGISTRY:
            eval_logger.info(
                f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
            )
        FILTER_REGISTRY[name] = filter_cls
        return filter_cls

    return decorate


Baber's avatar
cleanup  
Baber committed
202
def get_filter(filter_name: str | Callable) -> Callable:
    """Resolve a filter by registered name; callables pass through unchanged.

    Raises:
        KeyError: when ``filter_name`` is a string that was never registered.
    """
    try:
        return FILTER_REGISTRY[filter_name]
    except KeyError as err:
        if not callable(filter_name):
            eval_logger.warning(f"filter `{filter_name}` is not registered!")
            raise err
        return filter_name