# vllm_causallms.py -- lm-eval model backend that runs inference through vLLM.
from collections import defaultdict
from typing import List, Tuple, Optional, Literal

from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
import copy
from tqdm import tqdm
from lm_eval.api.registry import register_model
from lm_eval import utils
from vllm import LLM, SamplingParams

@register_model("vllm")
class VLLM(LM):
    """lm-eval ``LM`` implementation backed by a vLLM ``LLM`` engine.

    Registered under the model name ``"vllm"``. Log-likelihoods are read from the
    ``prompt_logprobs`` vLLM returns for pre-tokenized prompts; free-form
    generation goes through ``LLM.generate`` with ``SamplingParams``.
    """

    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained="gpt2",
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: int = 1,
        max_length: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
    ):
        """Build the vLLM engine and cache its tokenizer.

        :param pretrained: model name or path, forwarded as ``LLM(model=...)``
        :param dtype: forwarded to ``LLM``
        :param revision: forwarded to ``LLM``
        :param trust_remote_code: forwarded to ``LLM``
        :param tokenizer_mode: forwarded to ``LLM``
        :param tensor_parallel_size: forwarded to ``LLM`` (cast to ``int``)
        :param quantization: forwarded to ``LLM``
        :param max_gen_toks: default generation budget used by ``generate_until``
        :param swap_space: forwarded to ``LLM`` (cast to ``int``)
        :param batch_size: number of requests handed to the engine per call
        :param max_length: if set (truthy), overrides the context length reported
            by the ``max_length`` property
        :param gpu_memory_utilization: forwarded to ``LLM``; defaults to 0.9
        """
        super().__init__()

        # Numeric arguments are cast explicitly so string values such as "4" work,
        # mirroring the int() cast that tensor_parallel_size already had.
        self.model = LLM(
            model=pretrained,
            gpu_memory_utilization=float(gpu_memory_utilization),
            revision=revision,
            dtype=dtype,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=int(tensor_parallel_size),
            swap_space=int(swap_space),
            quantization=quantization,
        )
        self.tokenizer = self.model.get_tokenizer()
        # NOTE(review): used as-is as the chunk size for utils.chunks and, unlike
        # the other numeric arguments, not cast -- confirm callers pass an int.
        self.batch_size = batch_size
        self._max_length = int(max_length) if max_length is not None else None
        self._max_gen_toks = int(max_gen_toks)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Context length: the manual override, else the engine's ``max_model_len``, else the default."""
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        if hasattr(self.model.llm_engine.model_config, "max_model_len"):
            return self.model.llm_engine.model_config.max_model_len
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        return self._max_gen_toks

    def tok_encode(
        self,
        string: str,
        left_truncate_len=None,
        add_special_tokens=False,
        truncation=False,
    ):
        """Encode ``string`` with the engine's tokenizer.

        :param string: text to encode
        :param left_truncate_len: if truthy, keep only the last
            ``left_truncate_len`` tokens
        :param add_special_tokens: forwarded to ``tokenizer.encode``
        :param truncation: forwarded to ``tokenizer.encode``
        :return: the token ids returned by ``tokenizer.encode`` (possibly left-truncated)
        """
        encoding = self.tokenizer.encode(
            string, add_special_tokens=add_special_tokens, truncation=truncation
        )

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    @staticmethod
    def _modify_gen_kwargs(kwargs: dict) -> dict:
        """Translate HF-style generation kwargs into ``SamplingParams`` kwargs.

        ``SamplingParams`` has no ``do_sample`` argument, so it is removed. If it
        was explicitly ``False`` and no temperature was supplied, ``temperature``
        is set to 0.0 so vLLM decodes greedily instead of sampling with its
        default temperature. The input dict is not mutated.
        """
        kwargs = dict(kwargs)
        do_sample = kwargs.pop("do_sample", None)
        if do_sample is False and "temperature" not in kwargs:
            kwargs["temperature"] = 0.0
        return kwargs

    def _model_generate(
        self,
        requests: Optional[List[List[int]]] = None,
        generate: bool = False,
        max_tokens: int = None,
        stop: Optional[List[str]] = None,
        use_tqdm=True,
        **kwargs,
    ):
        """Run the vLLM engine on pre-tokenized prompts.

        :param requests: one list of token ids per prompt
        :param generate: if True, generate text using ``max_tokens``, ``stop`` and
            the remaining ``kwargs`` as sampling parameters; if False, score the
            prompts (``prompt_logprobs``) and ignore ``kwargs``
        :param max_tokens: generation budget (generate mode only)
        :param stop: stop strings (generate mode only)
        :param use_tqdm: forwarded to ``LLM.generate``
        :return: the outputs of ``LLM.generate``, one per prompt
        """
        if generate:
            sampling_params = SamplingParams(
                max_tokens=max_tokens, stop=stop, **self._modify_gen_kwargs(kwargs)
            )
        else:
            # Scoring mode: only the prompt logprobs matter. prompt_logprobs=2
            # includes the top candidates at each position, which
            # _parse_logprobs uses for the is_greedy check.
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=2, max_tokens=1
            )
        return self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
        )

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Score ``(context, continuation)`` requests.

        :return: one ``(continuation logprob sum, is_greedy)`` tuple per request
        """
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
                    continuation
                )
            else:
                context_enc, continuation_enc = self.tokenizer(
                    [context, continuation],
                    truncation="do_not_truncate",
                    add_special_tokens=False,
                    return_attention_mask=False,
                ).input_ids

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        """Return the total log-likelihood of each request's full string.

        The string is split into disjoint rolling windows (prefixed with the EOT
        token) and the windows' continuation log-likelihoods are summed.
        """
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests]):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        # NOTE(review): one less than max_length, presumably to
                        # leave room for the single scored token -- confirm.
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            # cache_key None: rolling windows are not partially cached
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # The outer tqdm already tracks progress; avoid a nested bar per string.
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
            )

            # discard is_greedy
            loglikelihoods.append(sum(x[0] for x in string_nll))
        return loglikelihoods

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate a continuation for each ``(context, gen_kwargs)`` request.

        Generation stops at any of the ``until`` strings (default: the decoded EOT
        token) or after ``max_gen_toks`` tokens. Results are returned in the
        original request order.
        """
        if not requests:
            return []

        res = defaultdict(list)
        re_ords = {}

        # batch tokenize contexts
        contexts, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encodings = self.tokenizer(list(contexts)).input_ids
        encoded_requests = [
            ((ctx, enc), gen_kwargs)
            for ctx, enc, gen_kwargs in zip(contexts, context_encodings, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), tuple(_requests[0][1])

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        grouper = utils.Grouper(encoded_requests, lambda x: str(x[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            # Only this group's requests go into its Reorderer.
            re_ords[key] = utils.Reorderer(reqs, _collate_gen)

        pbar = tqdm(total=len(encoded_requests), disable=(self.rank != 0))
        # for each different set of kwargs, we execute all requests, by batch.
        for key, re_ord in re_ords.items():
            chunks = utils.chunks(
                re_ord.get_reordered(),
                n=self.batch_size,
                fn=None,
            )
            for chunk in chunks:
                context_and_encoding, chunk_gen_kwargs = zip(*chunk)
                chunk_contexts, chunk_encodings = zip(*context_and_encoding)
                # we assume all gen kwargs in the batch are the same
                # this is safe to assume because the `grouper` object ensures it.
                gen_kwargs = chunk_gen_kwargs[0]
                # unpack our keyword arguments.
                until = None
                if isinstance(gen_kwargs, dict):
                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                    if "until" in kwargs:
                        until = kwargs.pop("until")
                        if isinstance(until, str):
                            until = [until]
                        elif not isinstance(until, list):
                            raise ValueError(
                                f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                            )
                else:
                    raise ValueError(
                        f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                    )
                if not until:
                    until = [self.tokenizer.decode(self.eot_token_id)]
                max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

                # set the max length in tokens of inputs ("context_enc")
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
                truncated_encodings = [x[-max_ctx_len:] for x in chunk_encodings]

                # TODO: max_length in kwargs

                # perform batched generation
                cont = self._model_generate(
                    requests=truncated_encodings,
                    generate=True,
                    max_tokens=max_gen_toks,
                    stop=until,
                    **kwargs,
                )

                # cache generations
                for output, ctx in zip(cont, chunk_contexts):
                    generated_text = output.outputs[0].text
                    res[key].append(generated_text)
                    self.cache_hook.add_partial(
                        "generate_until", (ctx, gen_kwargs), generated_text
                    )
                    pbar.update(1)

            # reorder this group of results back to original unsorted form
            res[key] = re_ord.get_original(res[key])

        pbar.close()

        return grouper.get_original(res)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        """Score pre-tokenized ``(cache_key, context_enc, continuation_enc)`` triples.

        Each prompt is left-truncated to ``max_length`` tokens; the context length
        is reduced by the same amount so the continuation stays aligned.

        :param requests: triples to score; ``cache_key`` None disables partial caching
        :param disable_tqdm: hide this method's progress bar and vLLM's own
        :return: ``(continuation logprob sum, is_greedy)`` per request, in input order
        """
        res = []
        max_length = self.max_length

        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        chunks = utils.chunks(
            re_ord.get_reordered(),
            n=self.batch_size,
            fn=None,
        )
        pbar = tqdm(total=len(requests), disable=disable_tqdm)
        for chunk in chunks:
            inputs = []
            ctxlens = []
            for _, context_enc, continuation_enc in chunk:
                inp = (context_enc + continuation_enc)[-max_length:]
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - max_length
                )

                inputs.append(inp)
                ctxlens.append(ctxlen)

            outputs = self._model_generate(
                requests=inputs, generate=False, use_tqdm=not disable_tqdm
            )

            for output, ctxlen, (cache_key, _, _), inp in zip(
                outputs, ctxlens, chunk, inputs
            ):
                # Parse against the (possibly truncated) tokens actually sent to the
                # engine so they line up with prompt_logprobs and ctxlen.
                answer = self._parse_logprobs(
                    tokens=inp,
                    outputs=output,
                    ctxlen=ctxlen,
                )

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)
        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Tokens from context+continuations, exactly as sent to the engine
            (after any left-truncation), so ``tokens[i]`` lines up with
            ``outputs.prompt_logprobs[i]``.
        :param outputs: RequestOutput
            Contains prompt_logprobs
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions).
            Must be >= 1, because the first prompt position has no logprob entry.
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

        def _as_float(logprob):
            # Accept both plain floats and objects exposing a ``.logprob`` field
            # (some vLLM versions wrap prompt logprobs that way).
            return getattr(logprob, "logprob", logprob)

        # prompt_logprobs = [None, {token_id: logprob, ...}, ...]: one entry per
        # prompt position, with None at position 0.
        continuation_tokens = tokens[ctxlen:]
        continuation_logprobs_dicts = outputs.prompt_logprobs[ctxlen:]

        # Calculate continuation_logprobs
        continuation_logprobs = sum(
            _as_float(logprob_dict.get(token))
            for token, logprob_dict in zip(
                continuation_tokens, continuation_logprobs_dicts
            )
        )

        # Determine if is_greedy
        is_greedy = True
        for token, logprob_dict in zip(
            continuation_tokens, continuation_logprobs_dicts
        ):
            # Get the token with the maximum log probability from the logprob_dict
            if logprob_dict:  # Ensure the logprob_dict is not None
                top_token = max(
                    logprob_dict, key=lambda t: _as_float(logprob_dict[t])
                )
                if top_token != token:
                    is_greedy = False
                    break

        return continuation_logprobs, is_greedy