import abc
import hashlib
import json
import logging
import os
from typing import List, Optional, Tuple, Type, TypeVar

import transformers
from sqlitedict import SqliteDict
from tqdm import tqdm

from lm_eval import utils


eval_logger = logging.getLogger("lm-eval")

T = TypeVar("T", bound="LM")


class LM(abc.ABC):
    def __init__(self) -> None:
        """Defines the interface that should be implemented by all LM subclasses.
        LMs are assumed to take text (strings) as input and yield strings as output
        (inputs/outputs should be tokenization-agnostic.)

        """
        # set rank and world size to a single process, by default.
        self._rank = 0
        self._world_size = 1
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list[Instance]
            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
            `context: str`
                Context string. Implementations of LM must be able to handle an
                empty context string.
            `continuation: str`
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.

        :return: list[tuple[float, bool]]
            A list of pairs (logprob, isgreedy)
            `logprob: float`
                The log probability of `continuation`.
            `isgreedy: bool`
                Whether `continuation` would be generated by greedy sampling from `context`.
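
            Example (illustrative; the numeric value is made up):
                For context="The capital of France is" and continuation=" Paris",
                a model might return (-1.9, True): the log probability of " Paris"
                given the context, and True if " Paris" is also the greedy continuation.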
        """
        pass

    @abc.abstractmethod
    def loglikelihood_rolling(self, requests) -> List[Tuple[float]]:
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: BOS/EOS
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  BOS   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context,).
            `context: str`
                String for which we are computing overall loglikelihood
        :return: list[tuple[float]]
            A list of tuples (logprob,)
            logprob: float
                The log probability of `context` conditioned on the BOS/EOS token.
                Can also be overridden for custom cases by `prefix_token_id`.
        """
        pass

    # TODO: Add an optional max length
    @abc.abstractmethod
    def generate_until(self, requests) -> List[str]:
        """Generate greedily until a stopping sequence

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context, until).
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list[str]
            A list of generated continuations
            continuation: str
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(
        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given argument string and additional config.

        Parameters:
        - arg_string: A string containing arguments in the format key1=value1,key2=value2.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
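
        Example (illustrative; `SomeLM` is a hypothetical subclass whose constructor accepts these kwargs):
            lm = SomeLM.create_from_arg_string("pretrained=gpt2,batch_size=8", {"device": "cuda"})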
        """
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    @classmethod
    def create_from_arg_obj(
        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given argument dictionary and additional config.

        Parameters:
        - arg_dict: A dict of keyword arguments used to construct the LM.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
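
        Example (illustrative; `SomeLM` is a hypothetical subclass):
            lm = SomeLM.create_from_arg_obj({"pretrained": "gpt2"}, {"batch_size": 8})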
        """

        additional_config = {} if additional_config is None else additional_config
        additional_config = {
            k: v for k, v in additional_config.items() if v is not None
        }

        return cls(**arg_dict, **additional_config)

    @property
    def rank(self):
        # used in the case of parallelism. Hardcoded to
        # ensure no errors arise when using API models which
        # neither support nor expect multi-device parallelism.
        return self._rank

    @property
    def world_size(self):
        # used in the case of parallelism. Hardcoded to
        # ensure no errors arise when using API models which
        # neither support nor expect multi-device parallelism.
        return self._world_size

    def set_cache_hook(self, cache_hook) -> None:
        self.cache_hook = cache_hook


### SQLite-based caching of LM responses
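# hash_args builds a deterministic cache key: the method name and request args are
# JSON-serialized together and hashed, e.g. hash_args("loglikelihood", ("ctx", " cont"))
# returns the SHA-256 hex digest of '["loglikelihood", "ctx", " cont"]'.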
def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm) -> None:
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res) -> None:
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db) -> None:
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
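
        Example (illustrative; `base_lm` and the cache path are placeholders):
            lm = CachingLM(base_lm, "lm_cache/my_model.db")
            results = lm.loglikelihood(requests)  # cache misses fall through to base_lm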
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        lm_attr = getattr(self.lm, attr)
        if not callable(lm_attr):
            return lm_attr

        def fn(requests):
            res = []
            remaining_reqs = []
            warned = False
            # figure out which ones are cached and which ones are new
            eval_logger.info(
                f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
            )
            for req in tqdm(requests, desc="Checking cached requests"):
                hsh = hash_args(attr, req.args)
                if attr == "generate_until" and req.args[1].get("do_sample", False):
                    # when we are doing non-greedy generation, don't use the cache
                    # (else every "randomly sampled" generation would be identical for repeats > 1).
                    if not warned:
                        eval_logger.warning(
                            f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
                        )
                        warned = True
                    res.append(None)
                    remaining_reqs.append(req)
                elif hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
            eval_logger.info(
                f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
            )
            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req.args)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


class TemplateLM(LM):
    """
    A class acting as an intermediary between the LM base class
    and the boilerplate often included in other LM subclasses.
    """

    @property
    @abc.abstractmethod
    def eot_token_id(self):
        pass

    @property
    def prefix_token_id(self):
        # token used as a prefix for loglikelihood computations, e.g. when the context is empty
        return self.eot_token_id

    @abc.abstractmethod
    def tok_encode(self, string: str, **kwargs):
        pass

    @abc.abstractmethod
    def _loglikelihood_tokens(self, requests, **kwargs):
        pass

    def _encode_pair(self, context, continuation):
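        # Move any trailing whitespace on the context over to the continuation, so
        # the word-boundary space is tokenized (and scored) as part of the
        # continuation, e.g. ("hello ", "world") becomes ("hello", " world").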
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        model_class = getattr(self, "AUTO_MODEL_CLASS", None)

        if model_class == transformers.AutoModelForSeq2SeqLM:
            context_enc = self.tok_encode(context)
            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
        else:
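            # For non-seq2seq models, encode the context and the full string, then take
            # the continuation tokens from the full encoding so they reflect how the
            # model tokenizes the continuation in context.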
            whole_enc = self.tok_encode(context + continuation)
            context_enc = self.tok_encode(context)

            context_enc_len = len(context_enc)
            continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

    def loglikelihood(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # BOS or EOS as context
                context_enc, continuation_enc = (
                    [self.prefix_token_id],
                    self.tok_encode(continuation),
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)

    @abc.abstractmethod
    def loglikelihood_rolling(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float]]:
        pass

    @abc.abstractmethod
    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        pass