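"""vLLM-backed causal language model for lm-evaluation-harness.

Registered under the model name "vllm"; typically selected from the CLI, e.g.:

    lm_eval --model vllm --model_args pretrained=<hf_model_id>,tensor_parallel_size=1
"""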
import copy
import inspect
import logging
from importlib.metadata import version
from importlib.util import find_spec
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

from more_itertools import distribute
from packaging.version import parse as parse_version
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    configure_pad_token,
    handle_stop_sequences,
    undistribute,
)
from lm_eval.utils import (
    get_rolling_token_windows,
    make_disjoint_window,
)


try:
    import ray
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer import get_tokenizer

    if parse_version(version("vllm")) >= parse_version("0.8.3"):
        from vllm.entrypoints.chat_utils import resolve_hf_chat_template
except ModuleNotFoundError:
    pass

if TYPE_CHECKING:
    pass

eval_logger = logging.getLogger(__name__)


@register_model("vllm")
class VLLM(TemplateLM):
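    """TemplateLM implementation backed by the vLLM inference engine.

    Supports tensor parallelism via vLLM and data parallelism via Ray,
    optional LoRA adapters, and chat templating through the HF tokenizer.
    """
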
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: str,
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer: Optional[str] = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: Optional[str] = None,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
        max_length: int = None,
        max_model_len: int = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: str = "cuda",
        data_parallel_size: int = 1,
        lora_local_path: str = None,
        enable_thinking: bool = False,
        **kwargs,
    ):
        super().__init__()

        if not find_spec("vllm"):
            raise ModuleNotFoundError(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            )

        assert max_length is None or max_model_len is None, (
            "Either max_length or max_model_len may be provided, but not both"
        )

        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
        self.model_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
            "device": str(device),
        }
        self.model_args.update(kwargs)
        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
            else int(batch_size)
        )
        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
            self.model_args["distributed_executor_backend"] = "ray"
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")

        from transformers import AutoConfig

        self._config = AutoConfig.from_pretrained(
            pretrained, trust_remote_code=trust_remote_code, revision=revision
        )
        self.tokenizer = get_tokenizer(
            tokenizer if tokenizer else pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            revision=tokenizer_revision,
            add_bos_token=add_bos_token,
        )
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
        self.enable_thinking = enable_thinking
        self.add_bos_token = add_bos_token
        if "gemma" in pretrained.lower():
            self.add_bos_token = True
            eval_logger.info(
                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
            )

        if parse_version(version("vllm")) >= parse_version("0.8.3"):
            kwargs_resolve_hf_chat_template = {
                "tokenizer": self.tokenizer,
                "chat_template": None,
                "tools": None,
            }

            if parse_version(version("vllm")) >= parse_version("0.9.0"):
                kwargs_resolve_hf_chat_template["model_config"] = (
                    self.model.llm_engine.model_config
                )

            # Some vLLM versions expose this parameter under the misspelled name
            # "trsut_remote_code"; pass whichever spelling the installed version accepts.
            # https://github.com/vllm-project/vllm/pull/18259
            if (
                "trsut_remote_code"
                in inspect.signature(resolve_hf_chat_template).parameters
            ):
                kwargs_resolve_hf_chat_template["trsut_remote_code"] = trust_remote_code
            else:
                kwargs_resolve_hf_chat_template["trust_remote_code"] = trust_remote_code

            self.hf_chat_template = resolve_hf_chat_template(
                **kwargs_resolve_hf_chat_template
            )
        else:
            self.hf_chat_template = None

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

        self._max_gen_toks = max_gen_toks

        if lora_local_path is not None:
            assert parse_version(version("vllm")) > parse_version("0.3.0"), (
                "lora adapters only compatible with vllm > v0.3.0."
            )
            self.lora_request = LoRARequest("finetuned", 1, lora_local_path)
        else:
            self.lora_request = None

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
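        # otherwise, infer it from the engine config, HF config attributes, or the tokenizer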
        if self.data_parallel_size <= 1:
            return self.model.llm_engine.model_config.max_model_len
        else:
            seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
            for attr in seqlen_config_attrs:
                if hasattr(self._config, attr):
                    return getattr(self._config, attr)
            if hasattr(self.tokenizer, "model_max_length"):
                if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                    return self._DEFAULT_MAX_LENGTH
                return self.tokenizer.model_max_length
            return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        return self._max_gen_toks

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.
        """
        chat_templated = self.tokenizer.apply_chat_template(
            chat_history,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
            chat_template=self.hf_chat_template,
            enable_thinking=self.enable_thinking,
        )

        return chat_templated

    @property
    def tokenizer_name(self) -> str:
        return self.tokenizer.name_or_path.replace("/", "__")

    def tok_encode(
        self,
        string: Union[str, List[str]],
        left_truncate_len: int = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
    ) -> Union[List[int], List[List[int]]]:
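        """Tokenize a string or a batch of strings, optionally adding special
        tokens and left-truncating each encoding to `left_truncate_len` tokens."""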
        if not add_special_tokens:
            add_special_tokens = self.add_bos_token
        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
            string,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            return_attention_mask=False,
        ).input_ids

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            if not isinstance(string, str):
                encoding = [enc[-left_truncate_len:] for enc in encoding]
            else:
                encoding = encoding[-left_truncate_len:]

        return encoding

    def _model_generate(
        self,
        requests: List[List[int]] = None,
        generate: bool = False,
        max_tokens: int = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
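        """Run vLLM generation (or prompt scoring when ``generate=False``) over
        pre-tokenized requests, dispatching to Ray workers when data parallelism
        is enabled."""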
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )
        if self.data_parallel_size > 1:
            # vLLM hangs if resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            @ray.remote
            def run_inference_one_model(
                model_args: dict,
                sampling_params: SamplingParams,
                requests: List[List[int]],
                lora_request: LoRARequest,
            ):
                llm = LLM(**model_args)
                return llm.generate(
                    prompt_token_ids=requests,
                    sampling_params=sampling_params,
                    lora_request=lora_request,
                )

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = (
                (self.model_args, sampling_params, req, self.lora_request)
                for req in requests
            )
            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
            results = ray.get(object_refs)
            # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
            ray.shutdown()
            # flatten results
            return undistribute(results)

        outputs = self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            use_tqdm=True if self.batch_size == "auto" else False,
            lora_request=self.lora_request,
        )
        return outputs

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
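        """Compute rolling-window loglikelihoods over full texts: each request is
        split into overlapping token windows, the windows are scored in batches,
        and per-window loglikelihoods are summed back per request."""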
        adaptive_batch_size = None
        if self.batch_size == "auto":
            adaptive_batch_size = len(requests)

        # First, collect all windows from all requests
        all_windows = []  # List of (request_idx, window) tuples
        request_window_counts = []  # Track number of windows per request

        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                map(
                    make_disjoint_window,
                    get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        # max_seq_len - (1 for context)
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [(None,) + x for x in rolling_token_windows]

            # Store windows with their request index
            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))

        all_nlls = []
        batch_size = adaptive_batch_size or int(self.batch_size)
        for i in range(0, len(all_windows), batch_size):
            batch = all_windows[i : i + batch_size]
            # Extract just the windows for processing, keeping track of request indices
            batch_indices, batch_windows = zip(*batch)

            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                disable_tqdm=False,
            )
            # Store results with their request indices
            all_nlls.extend(zip(batch_indices, batch_nlls))

        # Reconstruct per-request loglikelihoods
        loglikelihoods = []
        current_idx = 0
        for window_count in request_window_counts:
            # Get all nlls for this request
            request_nlls = all_nlls[current_idx : current_idx + window_count]
            # Sum up the nlls for this request (discarding is_greedy)
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            current_idx += window_count

            string = requests[len(loglikelihoods) - 1].args[0]
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (string,), request_total
            )

        return loglikelihoods

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
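        """Generate a completion for each request, batching requests that share
        identical generation kwargs and stopping at the requested stop sequences."""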
        res = []

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding: List[List[int]] = self.tok_encode(
            context, add_special_tokens=self.add_bos_token
        )
        requests = [
            ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), _requests[0][0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
        chunks = re_ords.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        # for each different set of kwargs, we execute all requests, by batch.
        eos = self.tokenizer.decode(self.eot_token_id)
        for chunk in chunks:
            context_and_encoding, all_gen_kwargs = zip(*chunk)
            context, context_encoding = zip(*context_and_encoding)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                # add EOS token to stop sequences
                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # set the max length in tokens of inputs ("context_enc")
            # max len for inputs = max length, minus room to generate the max new tokens
            max_ctx_len = self.max_length - max_gen_toks
            all_lengths = [len(x) for x in context_encoding]
            for length in all_lengths:
                if length > max_ctx_len:
                    eval_logger.warning(
                        f"Context length {length} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
                    )
            context_encoding = [x[-max_ctx_len:] for x in context_encoding]

            # perform batched generation
            cont = self._model_generate(
                requests=context_encoding,
                generate=True,
                max_tokens=max_gen_toks,
                stop=until,
                **kwargs,
            )

            # cache generations
            for output, context in zip(cont, context):
                generated_text = output.outputs[0].text
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (context, gen_kwargs), generated_text
                )
                pbar.update(1)

        pbar.close()
        # reorder all group of results back to original unsorted form
        return re_ords.get_original(res)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
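        """Score (context, continuation) token pairs, returning for each request the
        continuation loglikelihood and whether the continuation is the greedy one."""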
        res = []

        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # Reorder requests by length and batch
        re_ord = Collator(requests, sort_fn=_collate)
        chunks = re_ord.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=disable_tqdm,
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inputs = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                if (
                    full_length := len(context_enc + continuation_enc)
                ) >= self.max_length:
                    eval_logger.warning(
                        f"Context length {full_length} exceeds max length ({self.max_length}). Truncating context."
                    )
                inp = (context_enc + continuation_enc)[-(self.max_length) :]
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
                )

                inputs.append(inp)
                ctxlens.append(ctxlen)

            outputs = self._model_generate(requests=inputs, generate=False)

            for output, ctxlen, (cache_key, _, _), inp in zip(
                outputs, ctxlens, chunk, inputs
            ):
                answer = self._parse_logprobs(
                    tokens=inp,
                    outputs=output,
                    ctxlen=ctxlen,
                )

                res.append(answer)

                if cache_key is not None:
                    # special case: loglikelihood_rolling produces a number of loglikelihood requests
                    # all with cache key None. instead do add_partial on the per-example level
                    # in the loglikelihood_rolling() function for those.
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)
        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Input tokens (potentially left-truncated)
        :param outputs: RequestOutput
            Contains prompt_logprobs
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

        # The first entry of prompt_logprobs is None because the model has no previous tokens to condition on.
        continuation_logprobs_dicts = outputs.prompt_logprobs

        def coerce_logprob_to_num(logprob):
            # vLLM changed the return type of logprobs from float
            # to a Logprob object storing the float value + extra data
            # (https://github.com/vllm-project/vllm/pull/3065).
            # If we are dealing with vllm's Logprob object, return
            # the logprob value stored as an attribute. Otherwise,
            # return the object itself (which should be a float
            # for older versions of vLLM).
            return getattr(logprob, "logprob", logprob)

        continuation_logprobs_dicts = [
            {
                token: coerce_logprob_to_num(logprob)
                for token, logprob in logprob_dict.items()
            }
            if logprob_dict is not None
            else None
            for logprob_dict in continuation_logprobs_dicts
        ]

        # Calculate continuation_logprobs
        # assume ctxlen always >= 1
        continuation_logprobs = sum(
            logprob_dict.get(token)
            for token, logprob_dict in zip(
                tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
            )
        )

        # Determine if is_greedy
        is_greedy = True
        for token, logprob_dict in zip(
            tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
        ):
            # Get the token with the maximum log probability from the logprob_dict
            if logprob_dict:  # Ensure the logprob_dict is not None
                top_token = max(logprob_dict, key=logprob_dict.get)
                if top_token != token:
                    is_greedy = False
                    break

        return continuation_logprobs, is_greedy

    @staticmethod
    def modify_gen_kwargs(kwargs: dict) -> dict:
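        """Translate HF-style generation kwargs into vLLM SamplingParams defaults."""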
        # sampling_params
        kwargs["temperature"] = kwargs.get("temperature", 0.0)
        do_sample = kwargs.pop("do_sample", None)
        if do_sample is False and "temperature" not in kwargs:
            eval_logger.debug(
                "Got `do_sample=False` and no temperature value, setting VLLM temperature to 0.0 ..."
            )
            kwargs["temperature"] = 0.0
        # hf defaults
        kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
        kwargs["spaces_between_special_tokens"] = kwargs.get(
            "spaces_between_special_tokens", False
        )
        return kwargs