vllm_causallms.py 16.4 KB
Newer Older
1
2
3
4
5
6
import copy
from importlib.util import find_spec
from typing import List, Literal, Optional, Tuple, Union

from tqdm import tqdm

baberabb's avatar
baberabb committed
7
8
9
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
10
11
12
13
14
15
16
from lm_eval.utils import (
    Collator,
    divide,
    eval_logger,
    get_rolling_token_windows,
    make_disjoint_window,
)
17

Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
18

19
try:
20
    import ray
baberabb's avatar
baberabb committed
21
    from ray.util.multiprocessing import Pool
22
    from vllm import LLM, SamplingParams
baberabb's avatar
baberabb committed
23
    from vllm.transformers_utils.tokenizer import get_tokenizer
24
25
except ModuleNotFoundError:
    pass
Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
26

27
# NOTE(review): no-op self-assignment -- `eval_logger` is already bound by the
# `from lm_eval.utils import (...)` statement above. Presumably kept as an
# explicit module-level re-export; confirm and consider removing.
eval_logger = eval_logger
baberabb's avatar
baberabb committed
28

baberabb's avatar
baberabb committed
29

baberabb's avatar
baberabb committed
30
# adapted from https://github.com/vllm-project/vllm/issues/367#issuecomment-1788341727
def run_inference_one_model(
    model_args: dict, sampling_params, requests: List[List[int]]
):
    """Build a fresh vLLM engine and run one shard of pre-tokenized prompts.

    This is the task dispatched through the data-parallel process pool in
    ``VLLM._model_generate``; the engine is therefore created inside the
    worker rather than passed in.

    :param model_args: keyword arguments forwarded verbatim to ``vllm.LLM``.
    :param sampling_params: ``vllm.SamplingParams`` applied to every prompt.
    :param requests: prompts, each given as a list of token ids.
    :return: whatever ``LLM.generate`` returns for these prompts.
    """
    engine = LLM(**model_args)
    outputs = engine.generate(
        prompt_token_ids=requests,
        sampling_params=sampling_params,
    )
    return outputs


baberabb's avatar
baberabb committed
38
39
40
41
42
43
44
45
46
47
@register_model("vllm")
class VLLM(LM):
    """lm-eval ``LM`` backend that runs inference through vLLM.

    With ``data_parallel_size <= 1`` a single ``vllm.LLM`` engine is built in
    ``__init__``. Otherwise a fresh engine per shard is built inside Ray pool
    workers on every ``_model_generate`` call (see ``run_inference_one_model``).
    """

    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained="gpt2",
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer: Optional[str] = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: Optional[str] = None,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
        max_length: Optional[int] = None,
        max_model_len: Optional[int] = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: Optional[str] = "cuda",
        data_parallel_size: int = 1,
    ):
        """Configure the vLLM backend.

        :param pretrained: model name/path, passed to vLLM as ``model``.
        :param dtype: forwarded to vLLM.
        :param revision: forwarded to vLLM (and to ``AutoConfig`` in
            data-parallel mode).
        :param trust_remote_code: forwarded to vLLM, ``get_tokenizer`` and
            ``AutoConfig``.
        :param tokenizer: tokenizer name/path; ``pretrained`` is used if unset.
        :param tokenizer_mode: forwarded to vLLM and ``get_tokenizer``.
        :param tokenizer_revision: forwarded to vLLM and ``get_tokenizer``.
        :param tensor_parallel_size: forwarded to vLLM.
        :param quantization: forwarded to vLLM.
        :param max_gen_toks: default generation budget when a request sets none.
        :param swap_space: forwarded to vLLM.
        :param batch_size: an int, or any string containing ``"auto"``; forced
            to ``"auto"`` when ``data_parallel_size > 1``.
        :param max_batch_size: accepted but not used by this class.
        :param max_length: alias of ``max_model_len``; give at most one of them.
        :param max_model_len: passed to vLLM as ``max_model_len`` and reported
            by ``max_length``.
        :param seed: forwarded to vLLM.
        :param gpu_memory_utilization: forwarded to vLLM.
        :param device: must be ``None`` or a string containing ``"cuda"``.
        :param data_parallel_size: if > 1, each generation call is sharded
            across a Ray process pool of this size.
        :raises ModuleNotFoundError: if the ``vllm`` package is not installed.
        """
        super().__init__()

        if not find_spec("vllm"):
            # ModuleNotFoundError subclasses Exception, so existing handlers still match.
            raise ModuleNotFoundError(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            )

        # Test for None first: `"cuda" in None` raises TypeError, which made the
        # `device=None` case unusable when the membership test came first.
        assert device is None or "cuda" in device, "vLLM only supports CUDA"
        assert (
            max_length is None or max_model_len is None
        ), "Either max_length or max_model_len may be provided, but not both"

        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
        self.model_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
        }
        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
            else batch_size
        )
        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            # Engines are created per shard inside pool workers
            # (`run_inference_one_model`); here we only flag Ray usage.
            self.model_args["worker_use_ray"] = True
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")

            from transformers import AutoConfig

            # No local engine exists in this mode, so `max_length` falls back
            # to inspecting the HF config.
            self._config = AutoConfig.from_pretrained(
                pretrained, trust_remote_code=trust_remote_code, revision=revision
            )
        self.tokenizer = get_tokenizer(
            tokenizer if tokenizer else pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
        )
        self._max_gen_toks = max_gen_toks

    @property
    def eot_token_id(self):
        """Id of the end-of-text marker, taken from the tokenizer's EOS token.

        Called "EOT" because end of *text* describes its role here better
        than end of *sentence*.
        """
        eos_id = self.tokenizer.eos_token_id
        return eos_id

    @property
    def max_length(self):
        """Maximum number of tokens the model can be fed.

        Resolution order: an explicitly configured (truthy) length; in
        single-engine mode the live engine's ``max_model_len``; otherwise the
        first of ``n_positions`` / ``max_position_embeddings`` / ``n_ctx`` on
        the HF config, then the tokenizer's ``model_max_length``, then
        ``_DEFAULT_MAX_LENGTH``.
        """
        # A manually supplied length always wins.
        if self._max_length:
            return self._max_length

        # Single engine: ask vLLM directly.
        if self.data_parallel_size <= 1:
            return self.model.llm_engine.model_config.max_model_len

        # Data-parallel mode has no local engine; consult the HF config.
        for attr_name in ("n_positions", "max_position_embeddings", "n_ctx"):
            if hasattr(self._config, attr_name):
                return getattr(self._config, attr_name)

        if not hasattr(self.tokenizer, "model_max_length"):
            return self._DEFAULT_MAX_LENGTH

        # Presumably the HF "no limit configured" sentinel, int(1e30).
        unset_sentinel = 1000000000000000019884624838656
        tokenizer_limit = self.tokenizer.model_max_length
        if tokenizer_limit == unset_sentinel:
            return self._DEFAULT_MAX_LENGTH
        return tokenizer_limit
baberabb's avatar
baberabb committed
142
143
144
145
146

    @property
    def max_gen_toks(self):
        """Default cap on generated tokens for requests that do not set one."""
        generation_cap = self._max_gen_toks
        return generation_cap

baberabb's avatar
baberabb committed
147
148
149
150
151
152
153
    def tok_encode(
        self,
        string: str,
        left_truncate_len=None,
        add_special_tokens=False,
        truncation=False,
    ):
        """Tokenize ``string`` into token ids.

        :param string: text to encode.
        :param left_truncate_len: if truthy, keep only the last
            ``left_truncate_len`` ids (tokens are dropped from the left).
        :param add_special_tokens: forwarded to ``tokenizer.encode``.
        :param truncation: forwarded to ``tokenizer.encode``.
        :return: the (possibly left-truncated) token ids.
        """
        token_ids = self.tokenizer.encode(
            string,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
        )
        if not left_truncate_len:
            return token_ids
        # A negative start index keeps the trailing `left_truncate_len` ids.
        return token_ids[-left_truncate_len:]

    def _model_generate(
        self,
        requests: Optional[List[List[int]]] = None,
        generate: bool = False,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        """Run vLLM on pre-tokenized prompts.

        :param requests: prompts, each a list of token ids.
        :param generate: if True, sample continuations with ``max_tokens``,
            ``stop`` and any extra ``SamplingParams`` keyword arguments in
            ``kwargs``. If False, run a scoring pass that generates a single
            token with ``temperature=0`` and requests ``prompt_logprobs=2``.
        :param max_tokens: generation budget (only used when ``generate``).
        :param stop: stop strings (only used when ``generate``).
        :return: one vLLM output per request, in request order (callers zip
            the result with their inputs).
        """
        # HF-style flag that is not a SamplingParams argument; drop it if present.
        kwargs.pop("do_sample", None)
        if generate:
            # hf defaults
            kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
            kwargs["spaces_between_special_tokens"] = kwargs.get(
                "spaces_between_special_tokens", False
            )
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=2, max_tokens=1
            )

        if self.data_parallel_size > 1:
            # Assumes `divide` yields contiguous, order-preserving shards, so
            # concatenating per-shard outputs restores the input order.
            shards = [list(x) for x in divide(requests, self.data_parallel_size)]
            # Skip empty shards (fewer requests than workers): every task builds
            # a full LLM engine, which is wasted work when it has nothing to run.
            inputs = [
                (self.model_args, sampling_params, shard) for shard in shards if shard
            ]
            try:
                with Pool(self.data_parallel_size) as pool:
                    results = pool.starmap(run_inference_one_model, inputs)
            finally:
                # Invoke ray.shutdown() to prevent hang-ups if subsequent calls
                # are required -- also when the pool raised.
                ray.shutdown()
            # flatten results
            return [item for sublist in results for item in sublist]

        outputs = self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            use_tqdm=self.batch_size == "auto",
        )
        return outputs

baberabb's avatar
baberabb committed
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        """Tokenize a (context, continuation) pair.

        Any trailing whitespace on the context is moved to the front of the
        continuation. The concatenation is then encoded once, and the
        continuation ids are the tail of that joint encoding after the first
        ``len(context ids)`` positions.

        :return: ``(context_ids, continuation_ids)``.
        """
        stripped = context.rstrip()
        trailing_ws = context[len(stripped):]
        if trailing_ws:
            context, continuation = stripped, trailing_ws + continuation

        joint_ids = self.tok_encode(context + continuation, add_special_tokens=False)
        context_ids = self.tok_encode(context, add_special_tokens=False)
        return context_ids, joint_ids[len(context_ids):]

baberabb's avatar
baberabb committed
219
    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Score each ``(context, continuation)`` request.

        An empty context is replaced by the single EOT token, and the
        continuation is then encoded on its own.

        :return: ``(log-likelihood, is_greedy)`` per request, in request order.
        """
        prepared = []
        for context, continuation in (req.args for req in requests):
            if context != "":
                ctx_ids, cont_ids = self._encode_pair(context, continuation)
            else:
                # end of text as context
                ctx_ids = [self.eot_token_id]
                cont_ids = self.tok_encode(continuation)
            prepared.append(((context, continuation), ctx_ids, cont_ids))

        return self._loglikelihood_tokens(prepared)

baberabb's avatar
baberabb committed
235
    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        """Total log-likelihood of each full string, scored in rolling windows.

        Each string is tokenized and split by ``get_rolling_token_windows``
        (EOT prefix, ``max_seq_len=self.max_length - 1``, ``context_len=1``).
        The windows are made disjoint, scored with ``_loglikelihood_tokens``,
        and their log-likelihoods are summed; ``is_greedy`` is discarded.

        :return: one summed log-likelihood per request, in request order.
        """
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests]):
            rolling_token_windows = list(
                map(
                    make_disjoint_window,
                    get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            # A None cache key makes `_loglikelihood_tokens` skip partial caching.
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # The outer loop already reports progress; without this, a new
            # progress bar would be opened (and closed) for every document.
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
            )

            # discard is_greedy
            loglikelihoods.append(sum(x[0] for x in string_nll))
        return loglikelihoods

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate a continuation for each request until a stop string.

        Each request's ``args`` is ``(context, gen_kwargs)``. Requests are
        grouped by their ``gen_kwargs`` and sorted longest-context first, then
        generated batch by batch. Results are returned in the original order.

        :return: one generated string per request.
        """
        generations = []

        # batch tokenize contexts
        contexts, gen_kwargs_per_req = zip(*(req.args for req in requests))
        encodings = self.tokenizer(contexts, add_special_tokens=False).input_ids
        requests = [
            ((ctx, enc), gkw)
            for ctx, enc, gkw in zip(contexts, encodings, gen_kwargs_per_req)
        ]

        def _sort_key(item):
            # Longest prompts first: time estimates are over- rather than
            # under-estimates, the first item of a batch gives the padded
            # length, and any OOM surfaces immediately.
            (ctx, enc), _ = item
            return -len(enc), ctx

        # Group requests by generation kwargs so e.g. greedy and temp=0.8
        # sampling never share a batch.
        collator = Collator(requests, _sort_key, grouping=True)
        batches = collator.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
        for batch in batches:
            ctx_enc_pairs, batch_gen_kwargs = zip(*batch)
            batch_contexts, batch_encodings = zip(*ctx_enc_pairs)
            # Safe to take the first: the collator groups batches by gen kwargs.
            gen_kwargs = batch_gen_kwargs[0]

            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                )
            kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1

            until = None
            if "until" in kwargs:
                until = kwargs.pop("until")
                if isinstance(until, str):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(
                        f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                    )
            if not until:
                until = [self.tokenizer.decode(self.eot_token_id)]

            max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

            # Keep only as many trailing prompt tokens as fit once the
            # generation budget is reserved.
            max_ctx_len = self.max_length - max_gen_toks
            batch_encodings = [enc[-max_ctx_len:] for enc in batch_encodings]

            outputs = self._model_generate(
                requests=batch_encodings,
                generate=True,
                max_tokens=max_gen_toks,
                stop=until,
                **kwargs,
            )

            # cache generations
            for output, ctx in zip(outputs, batch_contexts):
                text = output.outputs[0].text
                generations.append(text)
                self.cache_hook.add_partial("generate_until", (ctx, gen_kwargs), text)
                pbar.update(1)

        pbar.close()
        # reorder all group of results back to original unsorted form
        return collator.get_original(generations)
baberabb's avatar
baberabb committed
348
349

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        """Score already-tokenized continuations given their contexts.

        :param requests: ``(cache_key, context_ids, continuation_ids)`` triples;
            a ``None`` cache_key skips partial caching.
        :param disable_tqdm: hide the progress bar.
        :return: ``(continuation log-likelihood, is_greedy)`` per request, in
            the original request order.
        """
        scored = []

        def _sort_key(req):
            # Longest sequences first; the token tuple makes ties deterministic.
            all_ids = req[1] + req[2]
            return -len(all_ids), tuple(all_ids)

        # Reorder requests by length and batch
        collator = Collator(requests, sort_fn=_sort_key)
        batches = collator.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(total=len(requests), disable=disable_tqdm)
        for batch in batches:
            window = self.max_length
            prompts = []
            context_lengths = []
            for _, context_enc, continuation_enc in batch:
                # Left-truncate to the window; tokens that fall off the front
                # come out of the context, so its length shrinks accordingly.
                overflow = max(0, len(context_enc) + len(continuation_enc) - window)
                prompts.append((context_enc + continuation_enc)[-window:])
                context_lengths.append(len(context_enc) - overflow)

            outputs = self._model_generate(requests=prompts, generate=False)

            for output, ctxlen, (cache_key, _, _), prompt in zip(
                outputs, context_lengths, batch, prompts
            ):
                answer = self._parse_logprobs(
                    tokens=prompt,
                    outputs=output,
                    ctxlen=ctxlen,
                )
                scored.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)

        pbar.close()
        return collator.get_original(scored)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Input tokens (potentially left-truncated)
        :param outputs: RequestOutput
            Contains prompt_logprobs: per position, a dict mapping token id to
            either a float or (newer vLLM) an object with a ``.logprob`` float.
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

        def coerce_logprob_to_num(logprob):
            # Newer vLLM stores Logprob objects (value in `.logprob`) instead of
            # bare floats; summing or comparing those objects raises TypeError.
            # Floats from older versions pass through unchanged.
            return getattr(logprob, "logprob", logprob)

        # The first entry of prompt_logprobs is None because the model has no
        # previous tokens to condition on; only the continuation slice is used.
        continuation_tokens = tokens[ctxlen:]
        continuation_logprobs_dicts = [
            {tok: coerce_logprob_to_num(lp) for tok, lp in logprob_dict.items()}
            if logprob_dict is not None
            else None
            for logprob_dict in outputs.prompt_logprobs[ctxlen:]
        ]

        # Calculate continuation_logprobs
        # assume ctxlen always >= 1
        continuation_logprobs = sum(
            logprob_dict.get(token)
            for token, logprob_dict in zip(
                continuation_tokens, continuation_logprobs_dicts
            )
        )

        # Determine if is_greedy
        is_greedy = True
        for token, logprob_dict in zip(
            continuation_tokens, continuation_logprobs_dicts
        ):
            # Get the token with the maximum log probability from the logprob_dict
            if logprob_dict:  # Ensure the logprob_dict is not None
                top_token = max(logprob_dict, key=logprob_dict.get)
                if top_token != token:
                    is_greedy = False
                    break

        return continuation_logprobs, is_greedy