import abc
import hashlib
import json
import logging
import os
from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union

import transformers
from sqlitedict import SqliteDict
from tqdm import tqdm

from lm_eval import utils


eval_logger = logging.getLogger("lm-eval")

T = TypeVar("T", bound="LM")


class LM(abc.ABC):
    def __init__(self) -> None:
        """Defines the interface that should be implemented by all LM subclasses.
        LMs are assumed to take text (strings) as input and yield strings as output
        (inputs/outputs should be tokenization-agnostic.)

        """
        # set rank and world size to a single process, by default.
        self._rank = 0
        self._world_size = 1
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list[Instance]
            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
            `context: str`
                Context string. Implementations of LM must be able to handle an
                empty context string.
            `continuation: str`
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.

        :return: list[tuple[float, bool]]
            A list of pairs (logprob, isgreedy)
            `logprob: float`
                The log probability of `continuation`.
            `isgreedy`:
                Whether `continuation` would be generated by greedy sampling from `context`.
        """
        pass

    @abc.abstractmethod
    def loglikelihood_rolling(self, requests) -> List[float]:
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: BOS/EOS
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  BOS   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context,).
            context: str
                String for which we are computing overall loglikelihood
        :return: list[tuple[float]]
            A list of tuples (logprob,)
            logprob: float
                The log probability of `context` conditioned on the BOS/EOS token.
                Can also be overridden for custom cases by `prefix_token_id`.
        """
        pass

    # TODO: Add an optional max length
    @abc.abstractmethod
    def generate_until(self, requests) -> List[str]:
        """Generate greedily until a stopping sequence

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
            context: str
                Context string
            gen_kwargs: dict
                A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
        :return: list[str]
            A list of model generated continuations.
            continuation: str
                The generated continuation.
        """
        pass
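    # Illustrative sketch of a single request's `args` tuple. The gen_kwargs keys
    # shown are assumptions for illustration only ("until" is the stop-sequence
    # list described above; the exact keys come from the task configuration):
    #
    #   ("Translate to French: cat", {"until": ["\n"], "max_gen_toks": 128})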

    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
        """
        Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.

        :param chat_history: list[dict[str, str]]
            A list of dictionaries with keys 'role' and 'content'.
            Values are strings representing the role name and the content of the message, respectively.
        :return: str
            A string representing the chat history in a format that can be used as input to the LM.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
        )
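    # Illustrative sketch of the expected input: an implementation would typically
    # receive a chat_history such as
    #
    #   [{"role": "user", "content": "What is 2 + 2?"},
    #    {"role": "assistant", "content": "4"}]
    #
    # and return a single prompt string rendered with the model's own template.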

    @classmethod
    def create_from_arg_string(
        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given argument string and additional config.

        Parameters:
        - arg_string: A string containing arguments in the format key1=value1,key2=value2.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
        """
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
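    # Minimal usage sketch. `SomeLM` and the "pretrained"/"device"/"batch_size"
    # arguments are hypothetical, chosen only for illustration:
    #
    #   lm = SomeLM.create_from_arg_string(
    #       "pretrained=gpt2,device=cuda:0",
    #       additional_config={"batch_size": 8},
    #   )
    #
    # which is equivalent to SomeLM(pretrained="gpt2", device="cuda:0", batch_size=8).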

    @classmethod
    def create_from_arg_obj(
        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given argument dictionary and additional config.

        Parameters:
        - arg_dict: A dict of keyword arguments used to construct the LM instance.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
        """

        additional_config = {} if additional_config is None else additional_config
        additional_config = {
            k: v for k, v in additional_config.items() if v is not None
        }

        return cls(**arg_dict, **additional_config)

    @property
    def rank(self):
        # Used in the case of parallelism. Hardcoded by default to
        # ensure no errors arise when using API models which neither
        # support nor expect multi-device parallelism.
        return self._rank

    @property
    def world_size(self):
        # Used in the case of parallelism. Hardcoded by default to
        # ensure no errors arise when using API models which neither
        # support nor expect multi-device parallelism.
        return self._world_size

    @property
    def tokenizer_name(self) -> str:
        """Must be defined for LM subclasses which implement Chat Templating.
        Should return the name of the tokenizer or chat template used.
        Used only to properly fingerprint caches when requests are cached with `--cache_requests`; not used otherwise.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'tokenizer_name' property."
        )

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """Returns the chat template structure for user/assistant messages if a template is provided.
        This method is intended to be overridden in a subclass to define a specific chat template format.
        For models that do not support chat templates, this base implementation returns an empty string.
        """

        return ""

    def set_cache_hook(self, cache_hook) -> None:
        self.cache_hook = cache_hook


### SQLite-based caching of LM responses
def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()
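# Illustrative example (the request arguments below are hypothetical):
#   hash_args("loglikelihood", ("The capital of France is", " Paris"))
# JSON-serializes ["loglikelihood", "The capital of France is", " Paris"] and
# returns the hex SHA-256 digest of that string, giving a deterministic cache key.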


class CacheHook:
    def __init__(self, cachinglm) -> None:
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res) -> None:
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db) -> None:
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())
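        # Usage sketch (the cache path below is purely illustrative):
        #
        #   lm = CachingLM(lm, "lm_cache/my_model.db")
        #
        # After wrapping, loglikelihood / loglikelihood_rolling / generate_until
        # calls are answered from the SQLite cache where possible, and only
        # uncached requests are forwarded to the underlying LM.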

    def __getattr__(self, attr: str):
        lm_attr = getattr(self.lm, attr)
        if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
            eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
            return lm_attr

        def fn(requests):
            res = []
            remaining_reqs = []
            warned = False
            # figure out which ones are cached and which ones are new
            eval_logger.info(
                f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
            )
            for req in tqdm(requests, desc="Checking cached requests"):
                hsh = hash_args(attr, req.args)
                if attr == "generate_until" and req.args[1].get("do_sample", False):
                    # when we are doing non-greedy generation, don't use the cache
                    # (else every "randomly sampled" generation would be identical for repeats > 1).
                    if not warned:
                        eval_logger.warning(
                            f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
                        )
                        warned = True
                    res.append(None)
                    remaining_reqs.append(req)
                elif hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
            eval_logger.info(
                f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
            )
            if remaining_reqs:
                # actually run the LM on the requests that do not have cached results
                rem_res = getattr(self.lm, attr)(remaining_reqs)
            else:
                rem_res = []

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req.args)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


class TemplateLM(LM):
    """
    A class acting as intermediary between the LM base class
    and boilerplate often included in other LM subclasses.
    """

    tokenizer = None

    @property
    @abc.abstractmethod
    def eot_token_id(self):
        pass

    @property
    def prefix_token_id(self):
        # Token used as a prefix for loglikelihood computation; defaults to the EOT token.
        return self.eot_token_id

    @abc.abstractmethod
    def tok_encode(self, string: str, **kwargs) -> List[int]:
        """
        Tokenize a string using the model's tokenizer and return a list of token IDs.
        """
        pass

    @abc.abstractmethod
    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        pass

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
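        # e.g. context="hello " / continuation="world" becomes
        # context="hello" / continuation=" world", so that the word-boundary space
        # is tokenized as part of the continuation rather than the context.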

        model_class = getattr(self, "AUTO_MODEL_CLASS", None)

        if model_class == transformers.AutoModelForSeq2SeqLM:
            context_enc = self.tok_encode(context)
            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
        else:
            whole_enc = self.tok_encode(context + continuation)
            context_enc = self.tok_encode(context)

            context_enc_len = len(context_enc)
            continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

    def loglikelihood(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # BOS or EOS as context
                context_enc, continuation_enc = (
                    [self.prefix_token_id],
                    self.tok_encode(continuation),
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)

    @abc.abstractmethod
    def loglikelihood_rolling(
        self, requests, disable_tqdm: bool = False
    ) -> List[float]:
        pass

    @abc.abstractmethod
    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        pass

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """
        Set and get the appropriate chat template for the model.
        This method sets the tokenizer's chat_template and returns the template string for reproducibility.

        The template selection logic is adapted from the Transformers library's `apply_chat_template`
        method in the Tokenizer class. The original implementation can be found at:
        https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687

        This method ensures that the right template is chosen based on the following:
        0. If the model's 'tokenizer' attribute is None: assumes that there is only a single possible chat template, handled internally on the model provider side. Returns the empty string.
        1. If the model's tokenizer has multiple templates:
            a. Use the specified template if it exists in the dictionary.
            b. Use the default template from the list if no specific template is provided.
            c. Raise an error if no default template exists and no specific template is provided.
        2. If the model's tokenizer has a single template or no template:
            a. Use the tokenizer's chat template if available.
            b. Fall back to the default chat template if no tokenizer chat template exists.

        Args:
            chat_template (Union[bool, str]): Specifies the chat template to use.
                - If False or None, no template is applied.
                - If True, the default or only available template is used.
                - If a string, the template with the matching name is used.

        Returns:
            Optional[str]: The selected chat template, or None if no template is applied.
        """
        if self.tokenizer is None:
            return ""

        if chat_template is False or chat_template is None:
            eval_logger.warning(
                "model.chat_template was called with the chat_template set to False or None. "
                "Therefore no chat template will be applied. Make sure this is an intended behavior."
            )
            return None

        # Convert boolean chat_template to None to ensure compatibility with the adapted logic
        if isinstance(chat_template, bool):
            chat_template = None
        using_default_template = False

        # First, handle the cases when the model has a dict of multiple templates
        template = self.tokenizer.chat_template or self.tokenizer.default_chat_template

        if isinstance(template, dict):
            using_default_dict = self.tokenizer.chat_template is None

            if chat_template is not None:
                if chat_template in template:
                    selected_template = template[chat_template]
                    if using_default_dict:
                        using_default_template = True
                else:
                    raise ValueError(
                        f"The specified chat template '{chat_template}' is not available. "
                        f"Available template names are {sorted(template.keys())}."
                    )
            else:
                # If user didn't pass a chat template, use the default template from the dict
                if "default" in template:
                    selected_template = template["default"]
                    using_default_template = True
                else:
                    raise ValueError(
                        "This model has multiple chat templates with no default specified! Please either pass a chat "
                        "template or the name of the template you wish to use to the `chat_template` argument. Available "
                        f"template names are {sorted(template.keys())}."
                    )

        # Cases when the model has a single template or no template
        else:
            # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
            if isinstance(chat_template, str):
                eval_logger.warning(
                    "Chat template name provided, but the tokenizer's chat template is not a dictionary. "
                    "Using the tokenizer's chat template or the default template instead."
                )
            if self.tokenizer.chat_template is not None:
                selected_template = self.tokenizer.chat_template
            else:
                selected_template = self.tokenizer.default_chat_template
                using_default_template = True

        if using_default_template:
            eval_logger.warning(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        return selected_template