vllm_causallms.py 14.6 KB
Newer Older
baberabb's avatar
baberabb committed
1
from collections import defaultdict
baberabb's avatar
baberabb committed
2
from typing import List, Tuple, Optional, Literal, Union
baberabb's avatar
baberabb committed
3
from transformers import AutoTokenizer
baberabb's avatar
baberabb committed
4
5
6
7
8
9
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
import copy
from tqdm import tqdm
from lm_eval.api.registry import register_model
from lm_eval import utils
baberabb's avatar
baberabb committed
10
11
from ray.util.multiprocessing import Pool

Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
12

13
14
15
16
try:
    from vllm import LLM, SamplingParams
except ModuleNotFoundError:
    pass
Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
17
18


baberabb's avatar
baberabb committed
19
20
# Module-level logger for this file, reused from `lm_eval.utils`.
eval_logger = utils.eval_logger

baberabb's avatar
baberabb committed
21

baberabb's avatar
baberabb committed
22
23
24
25
26
def run_inference_one_gpu(
    model_args: dict, sampling_params, requests: List[List[int]]
):
    """Build a vLLM engine from ``model_args`` and run it over ``requests``.

    Used as the worker function of the data-parallel pool: every call builds
    its own engine.

    :param model_args: keyword arguments forwarded to ``vllm.LLM``
    :param sampling_params: ``vllm.SamplingParams`` applied to every prompt
    :param requests: iterable of prompts, each a list of token ids. It is
        materialised into a list because ``LLM.generate`` needs a sized
        sequence, not a lazy iterator.
    :return: list of vLLM ``RequestOutput`` objects, one per prompt. Returns an
        empty list, without loading a model, when there are no prompts.
    """
    prompt_token_ids = list(requests)
    if not prompt_token_ids:
        # Nothing to do: avoid loading a full model onto a GPU for an empty shard.
        return []
    llm = LLM(**model_args)
    return llm.generate(
        prompt_token_ids=prompt_token_ids, sampling_params=sampling_params
    )


baberabb's avatar
baberabb committed
27
28
29
30
31
32
33
34
35
36
37
38
@register_model("vllm")
class VLLM(LM):
    """lm-eval ``LM`` implementation backed by a vLLM engine.

    Prompts are tokenized with a Hugging Face ``AutoTokenizer`` and passed to
    vLLM as token ids. Tensor parallelism is delegated to vLLM
    (``tensor_parallel_size``).

    When ``data_parallel > 1``, no engine is built in this process. Instead,
    each ``_model_generate`` call fans the prompts out to a Ray multiprocessing
    pool, where every worker builds its own engine.
    """

    _DEFAULT_MAX_LENGTH = 2048
    # Value transformers stores in `tokenizer.model_max_length` when the
    # tokenizer config specifies no limit (its VERY_LARGE_INTEGER, int(1e30)).
    _TOKENIZER_NO_LIMIT = int(1e30)

    def __init__(
        self,
        pretrained="gpt2",
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tensor_parallel_size: int = 1,
        quantization: Optional[Literal["awq"]] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
        max_length: int = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: str = "cuda",
        data_parallel: int = 1,
    ):
        """Collect vLLM engine arguments, build the engine and load the tokenizer.

        :param pretrained: model name or path, used for both vLLM and the tokenizer
        :param dtype, revision, trust_remote_code, tokenizer_mode,
            tensor_parallel_size, quantization, swap_space, seed,
            gpu_memory_utilization: forwarded to ``vllm.LLM``.
            ``tokenizer_mode="auto"`` also selects the fast HF tokenizer.
        :param max_gen_toks: default generation budget when a request's
            gen_kwargs does not set ``max_gen_toks``
        :param batch_size: prompts per vLLM call. ``"auto"`` (or any value
            starting with ``"auto"``) is passed to ``utils.chunks`` as ``n=0``.
        :param max_batch_size: accepted but currently unused
        :param max_length: explicit maximum sequence length; overrides the
            tokenizer-derived value
        :param device: must contain ``"cuda"`` or be ``None``
        :param data_parallel: number of engine replicas; values > 1 use a Ray pool
        :raises Exception: if the ``vllm`` package is not installed
        """
        super().__init__()

        try:
            import vllm  # noqa: F401
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. \
please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`",
            )

        # Test `device is None` first: `"cuda" in None` would raise TypeError.
        assert device is None or "cuda" in device, "vLLM only supports CUDA"

        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel = int(data_parallel)
        self.model_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer_mode": tokenizer_mode,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
        }
        # With data parallelism each pool worker builds its own engine instead.
        if self.data_parallel <= 1:
            self.model = LLM(**self.model_args)
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast=True if tokenizer_mode == "auto" else False,
        )
        # Values may arrive as strings (e.g. batch_size="8"). Normalise them so
        # the `!= "auto"` checks and `utils.chunks` receive an int or "auto".
        self.batch_size = (
            "auto" if str(batch_size).startswith("auto") else int(batch_size)
        )
        self._max_length = int(max_length) if max_length is not None else None
        self._max_gen_toks = int(max_gen_toks)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum sequence length in tokens.

        Resolution order:
        1. the explicit ``max_length`` argument;
        2. the tokenizer's ``model_max_length``, unless it is transformers'
           "no limit" sentinel;
        3. ``_DEFAULT_MAX_LENGTH``.
        """
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        model_max_length = getattr(self.tokenizer, "model_max_length", None)
        # The sentinel would disable every truncation / windowing step below.
        if (
            model_max_length is not None
            and model_max_length < self._TOKENIZER_NO_LIMIT
        ):
            return model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        return self._max_gen_toks

    def tok_encode(
        self,
        string: str,
        left_truncate_len=None,
        add_special_tokens=False,
        truncation=False,
    ):
        """Encode ``string`` to token ids.

        :param left_truncate_len: if truthy, keep only the last this-many tokens
        :param add_special_tokens: forwarded to ``tokenizer.encode``
        :param truncation: forwarded to ``tokenizer.encode``
        """
        encoding = self.tokenizer.encode(
            string, add_special_tokens=add_special_tokens, truncation=truncation
        )

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    @staticmethod
    def _split_contiguous(items: list, n: int) -> List[list]:
        """Split ``items`` into ``n`` contiguous slices whose sizes differ by at most one."""
        size, remainder = divmod(len(items), n)
        shards, start = [], 0
        for i in range(n):
            end = start + size + (1 if i < remainder else 0)
            shards.append(items[start:end])
            start = end
        return shards

    def _model_generate(
        self,
        requests: List[List[int]] = None,
        generate: bool = False,
        max_tokens: int = None,
        stop: Optional[List[str]] = None,
        use_tqdm=True,
        **kwargs,
    ):
        """Run vLLM on pre-tokenized prompts.

        :param requests: list of prompts, each a list of token ids
        :param generate: if True, sample continuations with ``max_tokens``,
            ``stop`` and ``kwargs``. Otherwise score the prompts: temperature 0,
            one new token, ``prompt_logprobs=2``.
        :param max_tokens: generation budget (generate mode only)
        :param stop: stop strings (generate mode only)
        :param use_tqdm: forwarded to ``LLM.generate`` on the single-engine path
        :param kwargs: extra ``SamplingParams`` arguments (generate mode only).
            ``do_sample`` is an HF flag that is consumed here; ``False`` means
            greedy decoding.
        :return: list of vLLM ``RequestOutput`` objects, in the order of ``requests``
        """
        do_sample = kwargs.pop("do_sample", None)
        if generate:
            if do_sample is False:
                # Honour the HF meaning of do_sample=False (greedy) instead of
                # silently falling back to vLLM's default sampling temperature.
                kwargs["temperature"] = 0.0
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=2, max_tokens=1
            )

        if self.data_parallel > 1:
            # Contiguous shards: concatenating per-replica outputs then restores
            # the original request order (an interleaved split would scramble
            # it). Empty shards are skipped so no replica loads a model for nothing.
            shards = self._split_contiguous(list(requests), self.data_parallel)
            inputs = [
                (self.model_args, sampling_params, shard) for shard in shards if shard
            ]
            if not inputs:
                return []
            with Pool(processes=len(inputs)) as pool:
                results = pool.starmap(run_inference_one_gpu, inputs)
            # flatten results
            return [item for sublist in results for item in sublist]

        return self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
        )

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Score ``continuation`` given ``context`` for each request.

        Each ``Instance.args`` is ``(context, continuation)``. An empty context
        is replaced by the EOT token.

        :return: one ``(logprob_sum, is_greedy)`` tuple per request
        """
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
                    continuation
                )
            else:
                context_enc, continuation_enc = self.tokenizer(
                    [context, continuation],
                    truncation="do_not_truncate",
                    add_special_tokens=False,
                    return_attention_mask=False,
                ).input_ids

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        """Return the total log-likelihood of each whole string.

        Each ``Instance.args`` is ``(string,)``. Each string is scored over
        disjoint rolling windows, and the window scores are summed.
        """
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        # NOTE(review): the -1 presumably leaves room for the one
                        # token vLLM generates in scoring mode (max_tokens=1).
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            # No cache key for windows: they are not cached individually.
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # The outer bar already tracks progress; avoid one inner bar per string.
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            # discard is_greedy
            loglikelihoods.append(sum(x[0] for x in string_nll))
        return loglikelihoods

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate a continuation for each request until a stop sequence.

        Each ``Instance.args`` is ``(context, gen_kwargs)``. Requests are
        grouped by their gen_kwargs, so each vLLM call uses a single
        ``SamplingParams``. Within a group they are ordered by descending
        context length. Results are returned in the original request order.
        """
        res = defaultdict(list)
        re_ords = {}

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding = self.tokenizer(context).input_ids
        requests = [
            ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), tuple(_requests[0][1])

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        grouper = utils.Grouper(requests, lambda x: str(x[1]))
        for key, reqs in grouper.get_grouped().items():
            # Reorder only this group's requests: `grouper.get_original` maps
            # each group's results back onto that group's members.
            re_ords[key] = utils.Reorderer(reqs, _collate_gen)

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
        # for each different set of kwargs, we execute all requests, by batch.
        for key, re_ord in re_ords.items():
            chunks = utils.chunks(
                re_ord.get_reordered(),
                n=self.batch_size if self.batch_size != "auto" else 0,
                fn=None,
            )
            for chunk in chunks:
                context_and_encoding, all_gen_kwargs = zip(*chunk)
                contexts, context_encoding = zip(*context_and_encoding)
                # we assume all gen kwargs in the batch are the same
                # this is safe to assume because the `grouper` object ensures it.
                gen_kwargs = all_gen_kwargs[0]
                # unpack our keyword arguments.
                until = None
                if isinstance(gen_kwargs, dict):
                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                    if "until" in kwargs:
                        until = kwargs.pop("until")
                        if isinstance(until, str):
                            until = [until]
                        elif not isinstance(until, list):
                            raise ValueError(
                                f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                            )
                else:
                    raise ValueError(
                        f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                    )
                if not until:
                    until = [self.tokenizer.decode(self.eot_token_id)]
                max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

                # set the max length in tokens of inputs ("context_enc")
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
                context_encoding = [x[-max_ctx_len:] for x in context_encoding]

                # TODO: max_length in kwargs

                # perform batched generation
                cont = self._model_generate(
                    requests=context_encoding,
                    generate=True,
                    max_tokens=max_gen_toks,
                    stop=until,
                    **kwargs,
                )

                # cache generations
                for output, ctx in zip(cont, contexts):
                    generated_text = output.outputs[0].text
                    res[key].append(generated_text)
                    self.cache_hook.add_partial(
                        "generate_until", (ctx, gen_kwargs), generated_text
                    )
                    pbar.update(1)

            # reorder this group of results back to original unsorted form
            res[key] = re_ord.get_original(res[key])

        pbar.close()

        return grouper.get_original(res)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        """Score tokenized ``(cache_key, context_enc, continuation_enc)`` triples.

        Inputs longer than ``max_length`` are left-truncated. The context
        length is reduced by the number of dropped tokens.

        :param disable_tqdm: hide this method's progress bar
        :return: ``(logprob_sum, is_greedy)`` per request, in input order
        """
        res = []

        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        chunks = utils.chunks(
            re_ord.get_reordered(),
            n=self.batch_size if self.batch_size != "auto" else 0,
            fn=None,
        )
        pbar = tqdm(total=len(requests), disable=disable_tqdm)
        for chunk in chunks:
            inps = []
            ctxlens = []
            for _cache_key, context_enc, continuation_enc in chunk:
                inp = (context_enc + continuation_enc)[-(self.max_length) :]
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)

            outputs = self._model_generate(requests=inps, generate=False)

            for output, ctxlen, inp, (cache_key, _, _) in zip(
                outputs, ctxlens, inps, chunk
            ):
                # Parse against the (possibly truncated) tokens vLLM actually
                # scored, so tokens and prompt logprobs stay aligned.
                answer = self._parse_logprobs(inp, output, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                # Advance for every request, including uncached rolling windows.
                pbar.update(1)
        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Tokens from context+continuations, exactly as sent to vLLM
        :param outputs: RequestOutput
            Contains prompt_logprobs
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """
        # prompt_logprobs = [None, {}*len(context-1)]
        # Depending on the vLLM version the dict values are floats or objects
        # exposing `.logprob`; normalise both to floats.
        continuation_logprobs_dicts = [
            {tok: getattr(lp, "logprob", lp) for tok, lp in logprob_dict.items()}
            if logprob_dict is not None
            else None
            for logprob_dict in outputs.prompt_logprobs[ctxlen:]
        ]
        continuation_tokens = tokens[ctxlen:]

        # Calculate continuation_logprobs
        # assume ctxlen always >= 1, so no sliced entry is the leading None
        continuation_logprobs = sum(
            logprob_dict.get(token)
            for token, logprob_dict in zip(
                continuation_tokens, continuation_logprobs_dicts
            )
        )

        # Determine if is_greedy
        is_greedy = True
        for token, logprob_dict in zip(continuation_tokens, continuation_logprobs_dicts):
            # Get the token with the maximum log probability from the logprob_dict
            if logprob_dict:  # Ensure the logprob_dict is not None
                top_token = max(logprob_dict, key=logprob_dict.get)
                if top_token != token:
                    is_greedy = False
                    break

        return continuation_logprobs, is_greedy