import copy
from importlib.metadata import version
from importlib.util import find_spec
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from more_itertools import distribute
from packaging.version import parse as parse_version
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, undistribute
from lm_eval.utils import (
    eval_logger,
    get_rolling_token_windows,
    make_disjoint_window,
)

# vLLM is an optional dependency: VLLM.__init__ raises a descriptive error when
# it is missing, so a failed import here must not break importing this module.
try:
    from vllm import LLM, SamplingParams
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ModuleNotFoundError:
    pass

# Ray is only used for data-parallel inference (data_parallel_size > 1). It is
# imported separately so that a missing Ray cannot prevent the vLLM names above
# from being bound for ordinary single-process use.
try:
    import ray
except ModuleNotFoundError:
    pass

@register_model("vllm")
class VLLM(TemplateLM):
    """lm-eval model wrapper that runs inference through a vLLM ``LLM`` engine.

    With ``data_parallel_size <= 1`` a single in-process ``LLM`` is built. With
    ``data_parallel_size > 1`` each request batch is sharded across Ray tasks,
    and each task builds its own ``LLM`` (only supported for vllm < 0.3.3).
    """

    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained="gpt2",
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer: Optional[str] = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: Optional[str] = None,
        add_bos_token: Optional[bool] = False,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
        max_length: Optional[int] = None,
        max_model_len: Optional[int] = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: str = "cuda",
        data_parallel_size: int = 1,
        **kwargs,
    ):
        """Build the vLLM engine (or its Ray config) and the tokenizer.

        :param pretrained: model name/path passed to vLLM as ``model``.
        :param tokenizer: optional tokenizer name/path; defaults to ``pretrained``.
        :param add_bos_token: add special tokens when encoding text.
        :param batch_size: int, or any string containing "auto".
        :param max_batch_size: accepted for interface compatibility; unused here.
        :param max_length: alias of ``max_model_len``; pass at most one of them.
        :param device: must contain "cuda" (or be None).
        :param data_parallel_size: number of Ray workers; > 1 needs vllm < 0.3.3.
        :param kwargs: forwarded verbatim into the ``LLM(...)`` constructor args.
        :raises Exception: if the ``vllm`` package is not installed.
        """
        super().__init__()

        if not find_spec("vllm"):
            raise Exception(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            )

        # Check None first: `"cuda" in None` would raise a TypeError instead of
        # accepting the documented None value.
        assert device is None or "cuda" in device, "vLLM only supports CUDA"
        assert (
            max_length is None or max_model_len is None
        ), "Either max_length or max_model_len may be provided, but not both"

        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
        # Arguments for LLM(...); also shipped to every Ray worker in DP mode.
        self.model_args: Dict[str, Any] = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
        }
        self.model_args.update(kwargs)
        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
            else batch_size
        )

        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            assert parse_version(version("vllm")) < parse_version(
                "0.3.3"
            ), "data_parallel is only compatible with vllm < v0.3.3."
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
            self.model_args["worker_use_ray"] = True
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")

            # No local engine exists in DP mode, so max_length falls back to
            # reading the HF config.
            from transformers import AutoConfig

            self._config = AutoConfig.from_pretrained(
                pretrained, trust_remote_code=trust_remote_code, revision=revision
            )

        self.tokenizer = get_tokenizer(
            tokenizer if tokenizer else pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
        )
        self.add_bos_token = add_bos_token
        self._max_gen_toks = max_gen_toks

    @property
    def eot_token_id(self):
        """Token id used as end-of-text (the tokenizer's EOS id)."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum total sequence length (prompt + generation) in tokens.

        Resolution order: explicit ``max_length``/``max_model_len``; the vLLM
        engine's ``max_model_len`` (single-process mode); the HF config's
        positional-length attributes; the tokenizer's ``model_max_length``;
        finally ``_DEFAULT_MAX_LENGTH``.
        """
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        if self.data_parallel_size <= 1:
            return self.model.llm_engine.model_config.max_model_len
        for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
            if hasattr(self._config, attr):
                return getattr(self._config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            # This value is int(1e30), the placeholder transformers uses when a
            # tokenizer has no real length limit configured.
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        """Default number of tokens to generate when a request does not say."""
        return self._max_gen_toks

    def tok_encode(
        self,
        string: str,
        left_truncate_len=None,
        add_special_tokens=None,
        truncation=False,
    ):
        """Encode ``string`` into a list of token ids.

        :param string: text to encode.
        :param left_truncate_len: if truthy, keep only the last N tokens.
        :param add_special_tokens: when falsy, falls back to ``add_bos_token``.
        :param truncation: forwarded to ``tokenizer.encode``.
        :return: list of token ids.
        """
        if not add_special_tokens:
            # add_bos_token is Optional[bool]; coerce so None is never forwarded
            # to the tokenizer as the add_special_tokens flag.
            add_special_tokens = bool(self.add_bos_token)
        encoding = self.tokenizer.encode(
            string, add_special_tokens=add_special_tokens, truncation=truncation
        )

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    def _model_generate(
        self,
        requests: Optional[List[List[int]]] = None,
        generate: bool = False,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        """Run vLLM on pre-tokenized prompts.

        :param requests: list of prompt token-id lists.
        :param generate: True for free-form generation (``max_tokens``/``stop``
            and ``kwargs`` become SamplingParams); False for scoring, which
            generates a single greedy token and requests prompt logprobs.
        :return: vLLM outputs, one per prompt, in the order of ``requests``.
        """
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1
            )

        if self.data_parallel_size > 1:
            # vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            # note: this has changed on 0.3.3, and it only works now if num_gpus are set.
            # but then tensor_parallel breaks
            @ray.remote
            def run_inference_one_model(
                model_args: Dict[str, Any], sampling_params, requests: List[List[int]]
            ):
                llm = LLM(**model_args)
                return llm.generate(
                    prompt_token_ids=requests, sampling_params=sampling_params
                )

            # Dispatch requests to the workers in interleaved (round-robin)
            # fashion; callers sort requests by length, so interleaving
            # balances context lengths across workers.
            shards = [
                list(shard) for shard in distribute(self.data_parallel_size, requests)
            ]
            # distribute() leaves the *trailing* shards empty when there are
            # fewer requests than workers. Skip them instead of loading a whole
            # model copy to generate nothing; dropping trailing empties does not
            # change the order undistribute() reconstructs.
            shards = [shard for shard in shards if shard]
            object_refs = [
                run_inference_one_model.remote(self.model_args, sampling_params, shard)
                for shard in shards
            ]
            results = ray.get(object_refs)
            # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
            ray.shutdown()
            # flatten results back into the original request order
            return undistribute(results)

        return self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            use_tqdm=self.batch_size == "auto",
        )

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Total log-likelihood of each request's full string.

        The string is split into disjoint rolling windows of at most
        ``max_length - 1`` tokens, each window is scored, and the window
        log-likelihoods are summed.

        :param requests: instances whose ``args`` is a 1-tuple ``(string,)``.
        :param disable_tqdm: hide the per-document progress bar.
        :return: one summed log-likelihood per request.
        """
        loglikelihoods = []
        # max_length may re-read the config on each access; compute it once.
        max_seq_len = self.max_length - 1

        for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
            windows = map(
                make_disjoint_window,
                get_rolling_token_windows(
                    token_list=self.tok_encode(string),
                    prefix_token=self.eot_token_id,
                    max_seq_len=max_seq_len,
                    context_len=1,
                ),
            )
            # A None cache key means these per-window scores are not cached.
            rolling_token_windows = [(None,) + window for window in windows]

            # The outer bar already tracks progress; suppress one nested bar per document.
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            # discard is_greedy and sum the window log-likelihoods
            loglikelihoods.append(sum(logprob for logprob, _ in string_nll))
        return loglikelihoods

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate a continuation for each request.

        Each request's ``args`` is ``(context string, gen_kwargs dict)``. The
        ``gen_kwargs`` may contain ``until`` (str or list of stop strings, to
        which the EOS string is always added) and ``max_gen_toks``. Everything
        else is forwarded to vLLM's SamplingParams via ``modify_gen_kwargs``.

        :return: generated strings, in the original request order.
        :raises ValueError: if ``gen_kwargs`` or ``until`` has the wrong type.
        """
        if not requests:
            return []
        res = []

        # batch tokenize contexts; honor add_bos_token consistently with tok_encode
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding = self.tokenizer(
            context, add_special_tokens=bool(self.add_bos_token)
        ).input_ids
        requests = [
            ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), _requests[0][0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
        chunks = re_ords.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=(self.rank != 0),
            desc="Running generate_until requests",
        )
        # The EOS string is the same for every chunk; decode it once.
        eos = self.tokenizer.decode(self.eot_token_id)

        # for each different set of kwargs, we execute all requests, by batch.
        for chunk in chunks:
            context_and_encoding, all_gen_kwargs = zip(*chunk)
            contexts, context_encoding = zip(*context_and_encoding)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in kwargs:
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                )
            # add EOS token to stop sequences (`until` is a private copy here)
            if not until:
                until = [eos]
            else:
                until.append(eos)

            max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

            # set the max length in tokens of inputs ("context_enc")
            # max len for inputs = max length, minus room to generate the max new tokens
            max_ctx_len = self.max_length - max_gen_toks
            context_encoding = [x[-max_ctx_len:] for x in context_encoding]

            # perform batched generation
            cont = self._model_generate(
                requests=context_encoding,
                generate=True,
                max_tokens=max_gen_toks,
                stop=until,
                **kwargs,
            )

            # cache generations
            for output, ctx in zip(cont, contexts):
                generated_text = output.outputs[0].text
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (ctx, gen_kwargs), generated_text
                )
                pbar.update(1)

        pbar.close()
        # reorder all group of results back to original unsorted form
        return re_ords.get_original(res)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        """Score continuation tokens given context tokens.

        :param requests: ``(cache_key, context_enc, continuation_enc)`` triples;
            a ``None`` cache_key skips partial caching.
        :param disable_tqdm: hide the progress bar.
        :return: ``(continuation log-likelihood, is_greedy)`` per request, in
            the original request order.
        """
        res = []

        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # Reorder requests by length and batch
        re_ord = Collator(requests, sort_fn=_collate)
        chunks = re_ord.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=disable_tqdm,
            desc="Running loglikelihood requests",
        )
        # max_length may re-read the config on each access; compute it once.
        max_length = self.max_length
        for chunk in chunks:
            inputs = []
            ctxlens = []
            for _, context_enc, continuation_enc in chunk:
                # Left-truncate to the model window; the context shrinks by the
                # number of tokens cut so the continuation is always fully scored.
                inp = (context_enc + continuation_enc)[-max_length:]
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - max_length
                )
                inputs.append(inp)
                ctxlens.append(ctxlen)

            outputs = self._model_generate(requests=inputs, generate=False)

            for output, ctxlen, (cache_key, _, _), inp in zip(
                outputs, ctxlens, chunk, inputs
            ):
                answer = self._parse_logprobs(
                    tokens=inp,
                    outputs=output,
                    ctxlen=ctxlen,
                )

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)
        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Input tokens (potentially left-truncated)
        :param outputs: RequestOutput
            Contains prompt_logprobs
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

        def _as_float(logprob):
            # Newer vLLM versions return Logprob objects (value in `.logprob`)
            # instead of plain floats; older versions return floats directly.
            return getattr(logprob, "logprob", logprob)

        # The first entry of prompt_logprobs is None because the model has no
        # previous tokens to condition on; we assume ctxlen >= 1 so it is skipped.
        continuation_logprobs_dicts = [
            None
            if logprob_dict is None
            else {tok: _as_float(lp) for tok, lp in logprob_dict.items()}
            for logprob_dict in outputs.prompt_logprobs[ctxlen:]
        ]
        pairs = list(zip(tokens[ctxlen:], continuation_logprobs_dicts))

        # Calculate continuation_logprobs
        continuation_logprobs = sum(
            logprob_dict.get(token) for token, logprob_dict in pairs
        )

        # Greedy iff every continuation token is the argmax of its position;
        # positions without a logprob dict are not counted against it.
        is_greedy = all(
            not logprob_dict or max(logprob_dict, key=logprob_dict.get) == token
            for token, logprob_dict in pairs
        )

        return continuation_logprobs, is_greedy

    @staticmethod
    def modify_gen_kwargs(kwargs: dict) -> dict:
        """Map HF-style generation kwargs onto vLLM SamplingParams kwargs.

        Mutates and returns ``kwargs``:
        - removes ``do_sample``;
        - sets ``temperature`` to 0.0 (greedy) when ``do_sample`` is False or
          no ``temperature`` was given;
        - defaults ``skip_special_tokens`` and
          ``spaces_between_special_tokens`` to False (HF decoding defaults).
        """
        # sampling_params
        do_sample = kwargs.pop("do_sample", None)
        if do_sample is False or "temperature" not in kwargs:
            kwargs["temperature"] = 0.0
        # hf defaults
        kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
        kwargs["spaces_between_special_tokens"] = kwargs.get(
            "spaces_between_special_tokens", False
        )
        return kwargs