import copy
import logging
from importlib.metadata import version
from importlib.util import find_spec
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

from more_itertools import distribute
from packaging.version import parse as parse_version
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    configure_pad_token,
    handle_stop_sequences,
    undistribute,
)
from lm_eval.utils import (
    get_rolling_token_windows,
    make_disjoint_window,
)


try:
    import ray
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ModuleNotFoundError:
    pass
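# `vllm` (and the ray / LoRA helpers imported with it) is an optional extra; a missing
# install is reported with an actionable error in VLLM.__init__ via find_spec("vllm").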

if TYPE_CHECKING:
    pass

eval_logger = logging.getLogger(__name__)


@register_model("vllm")
class VLLM(TemplateLM):
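    # Registered under the model key "vllm"; the lm_eval evaluator normally constructs this
    # class from `--model vllm --model_args ...`. A hypothetical direct-usage sketch
    # (argument values are illustrative only):
    #   lm = VLLM(pretrained="org/model-name", batch_size="auto")
    # Fallback context length used by `max_length` when no value can be read from the
    # model config or tokenizer: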
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: str,
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer: Optional[str] = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: Optional[str] = None,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
        max_length: int = None,
        max_model_len: int = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: str = "cuda",
        data_parallel_size: int = 1,
        lora_local_path: str = None,
        **kwargs,
    ):
        super().__init__()

        if not find_spec("vllm"):
            raise ModuleNotFoundError(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            )

        assert max_length is None or max_model_len is None, (
            "Either max_length or max_model_len may be provided, but not both"
        )

        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
        self.model_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
        }
        self.model_args.update(kwargs)
        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
            else int(batch_size)
        )
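        # With a single data-parallel replica the engine is instantiated eagerly here; with
        # data_parallel_size > 1, one engine per replica is created lazily inside Ray
        # workers in _model_generate instead.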
        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
            self.model_args["distributed_executor_backend"] = "ray"
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")

            from transformers import AutoConfig

            self._config = AutoConfig.from_pretrained(
                pretrained, trust_remote_code=trust_remote_code, revision=revision
            )
        self.tokenizer = get_tokenizer(
            tokenizer if tokenizer else pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            revision=tokenizer_revision,
            add_bos_token=add_bos_token,
        )
        self.tokenizer = configure_pad_token(self.tokenizer)
        self.add_bos_token = add_bos_token
        if "gemma" in pretrained.lower():
            self.add_bos_token = True
            eval_logger.info(
                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
            )

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

        self._max_gen_toks = max_gen_toks

        if lora_local_path is not None:
            assert parse_version(version("vllm")) > parse_version("0.3.0"), (
                "lora adapters only compatible with vllm > v0.3.0."
            )
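            # LoRARequest(adapter_name, adapter_id, lora_local_path): the name and id are
            # arbitrary identifiers for the single local adapter being attached.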
            self.lora_request = LoRARequest("finetuned", 1, lora_local_path)
        else:
            self.lora_request = None

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        if self.data_parallel_size <= 1:
            return self.model.llm_engine.model_config.max_model_len
        else:
            seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
            for attr in seqlen_config_attrs:
                if hasattr(self._config, attr):
                    return getattr(self._config, attr)
            if hasattr(self.tokenizer, "model_max_length"):
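                # HF tokenizers report this sentinel (int(1e30)) when model_max_length is unset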
                if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                    return self._DEFAULT_MAX_LENGTH
                return self.tokenizer.model_max_length
            return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        return self._max_gen_toks

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """
        Apply a chat template to a chat history (a list of user/assistant messages) and return the formatted string.
        """
        chat_templated = self.tokenizer.apply_chat_template(
            chat_history,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
        )
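        # When add_generation_prompt is False, `continue_final_message` above asks the
        # template to continue the final message in place instead of opening a new
        # assistant turn.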

        return chat_templated

    @property
    def tokenizer_name(self) -> str:
        return self.tokenizer.name_or_path.replace("/", "__")

    def tok_encode(
        self,
        string: Union[str, List[str]],
        left_truncate_len: int = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
    ) -> Union[List[int], List[List[int]]]:
        if not add_special_tokens:
            add_special_tokens = self.add_bos_token
        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
            string,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            return_attention_mask=False,
        ).input_ids

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            if not isinstance(string, str):
                encoding = [enc[-left_truncate_len:] for enc in encoding]
            else:
                encoding = encoding[-left_truncate_len:]

        return encoding

    def _model_generate(
        self,
        requests: List[List[int]] = None,
        generate: bool = False,
        max_tokens: int = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
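            # Loglikelihood scoring only needs logprobs over the prompt tokens: request
            # prompt_logprobs and generate a single throwaway token.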
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )
        if self.data_parallel_size > 1:
            # vLLM hangs if resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            @ray.remote
            def run_inference_one_model(
                model_args: dict,
                sampling_params: SamplingParams,
                requests: List[List[int]],
                lora_request: LoRARequest,
            ):
                llm = LLM(**model_args)
                return llm.generate(
                    prompt_token_ids=requests,
                    sampling_params=sampling_params,
                    lora_request=lora_request,
                )

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = (
                (self.model_args, sampling_params, req, self.lora_request)
                for req in requests
            )
            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
            results = ray.get(object_refs)
            # Invoke ray.shutdown() to prevent hang-ups if subsequent calls are required.
            ray.shutdown()
            # flatten results
            return undistribute(results)

        outputs = self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            use_tqdm=True if self.batch_size == "auto" else False,
            lora_request=self.lora_request,
        )
        return outputs

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        adaptive_batch_size = None
        if self.batch_size == "auto":
            adaptive_batch_size = len(requests)
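        # with batch_size="auto", windows are chunked by the number of requests and vLLM's
        # own scheduler handles batching within each call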

        # First, collect all windows from all requests
        all_windows = []  # List of (request_idx, window) tuples
        request_window_counts = []  # Track number of windows per request

        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                map(
                    make_disjoint_window,
                    get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        # max_seq_len - (1 for context)
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
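            # prepend a None cache key: rolling windows are cached per request below,
            # not per window inside _loglikelihood_tokens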
            windows = [(None,) + x for x in rolling_token_windows]

            # Store windows with their request index
            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))

        all_nlls = []
        batch_size = adaptive_batch_size or int(self.batch_size)
        for i in range(0, len(all_windows), batch_size):
            batch = all_windows[i : i + batch_size]
            # Extract just the windows for processing, keeping track of request indices
            batch_indices, batch_windows = zip(*batch)

            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                disable_tqdm=False,
            )
            # Store results with their request indices
            all_nlls.extend(zip(batch_indices, batch_nlls))

        # Reconstruct per-request loglikelihoods
        loglikelihoods = []
        current_idx = 0
        for window_count in request_window_counts:
            # Get all nlls for this request
            request_nlls = all_nlls[current_idx : current_idx + window_count]
            # Sum up the nlls for this request (discarding is_greedy)
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            current_idx += window_count

            string = requests[len(loglikelihoods) - 1].args[0]
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (string,), request_total
            )

        return loglikelihoods

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        res = []

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding: List[List[int]] = self.tok_encode(
            context, add_special_tokens=self.add_bos_token
        )
        requests = [
            ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
        ]
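        # each entry now pairs (context string, context token ids) with its gen_kwargs so
        # the raw string stays available for caching and the token ids for truncation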

        def _collate_gen(_requests):
            # the negative sign on the context length sorts in descending order - this has a few advantages:
            # - time estimates will always be overestimates rather than underestimates, which is more useful for planning
            # - when going through the list, you know the first request in a batch always has the batch's
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), _requests[0][0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
        chunks = re_ords.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        # for each different set of kwargs, we execute all requests, by batch.
        eos = self.tokenizer.decode(self.eot_token_id)
        for chunk in chunks:
            context_and_encoding, all_gen_kwargs = zip(*chunk)
            context, context_encoding = zip(*context_and_encoding)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                # add EOS token to stop sequences
                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # set the max length in tokens of inputs ("context_enc")
            # max len for inputs = max length, minus room to generate the max new tokens
            max_ctx_len = self.max_length - max_gen_toks
            context_encoding = [x[-max_ctx_len:] for x in context_encoding]

            # perform batched generation
            cont = self._model_generate(
                requests=context_encoding,
                generate=True,
                max_tokens=max_gen_toks,
                stop=until,
                **kwargs,
            )

            # cache generations
            for output, context in zip(cont, context):
                generated_text = output.outputs[0].text
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (context, gen_kwargs), generated_text
                )
                pbar.update(1)

        pbar.close()
        # reorder all group of results back to original unsorted form
        return re_ords.get_original(res)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        res = []

        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # Reorder requests by length and batch
        re_ord = Collator(requests, sort_fn=_collate)
        chunks = re_ord.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=disable_tqdm,
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inputs = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                inp = (context_enc + continuation_enc)[-(self.max_length) :]
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
                )
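                # ctxlen counts the context tokens that survive left-truncation; logprob
                # entries at positions >= ctxlen belong to the continuation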

                inputs.append(inp)
                ctxlens.append(ctxlen)

            outputs = self._model_generate(requests=inputs, generate=False)

            for output, ctxlen, (cache_key, _, _), inp in zip(
                outputs, ctxlens, chunk, inputs
            ):
                answer = self._parse_logprobs(
                    tokens=inp,
                    outputs=output,
                    ctxlen=ctxlen,
                )

                res.append(answer)

                if cache_key is not None:
                    # special case: loglikelihood_rolling produces a number of loglikelihood requests
                    # all with cache key None. instead do add_partial on the per-example level
                    # in the loglikelihood_rolling() function for those.
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)
        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Input tokens (potentially left-truncated)
        :param outputs: RequestOutput
            Contains prompt_logprobs
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Sum of the log probabilities of the continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

        # The first entry of prompt_logprobs is None because the model has no previous tokens to condition on.
        continuation_logprobs_dicts = outputs.prompt_logprobs

        def coerce_logprob_to_num(logprob):
            # vLLM changed the return type of logprobs from float
            # to a Logprob object storing the float value + extra data
            # (https://github.com/vllm-project/vllm/pull/3065).
            # If we are dealing with vllm's Logprob object, return
            # the logprob value stored as an attribute. Otherwise,
            # return the object itself (which should be a float
            # for older versions of vLLM).
            return getattr(logprob, "logprob", logprob)

        continuation_logprobs_dicts = [
            {
                token: coerce_logprob_to_num(logprob)
                for token, logprob in logprob_dict.items()
            }
            if logprob_dict is not None
            else None
            for logprob_dict in continuation_logprobs_dicts
        ]

        # Calculate continuation_logprobs
        # assume ctxlen always >= 1
        continuation_logprobs = sum(
            logprob_dict.get(token)
            for token, logprob_dict in zip(
                tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
            )
        )

        # Determine if is_greedy
        is_greedy = True
        for token, logprob_dict in zip(
            tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
        ):
            # Get the token with the maximum log probability from the logprob_dict
            if logprob_dict:  # Ensure the logprob_dict is not None
                top_token = max(logprob_dict, key=logprob_dict.get)
                if top_token != token:
                    is_greedy = False
                    break

        return continuation_logprobs, is_greedy

    @staticmethod
    def modify_gen_kwargs(kwargs: dict) -> dict:
        # sampling_params
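        # `do_sample` is an HF-style generation flag with no SamplingParams equivalent:
        # translate do_sample=False (greedy decoding) into temperature 0.0 below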
        do_sample = kwargs.pop("do_sample", None)
        if do_sample is False and "temperature" not in kwargs:
            eval_logger.debug(
                "Got `do_sample=False` and no temperature value, setting VLLM temperature to 0.0 ..."
            )
            kwargs["temperature"] = 0.0
        # hf defaults
        kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
        kwargs["spaces_between_special_tokens"] = kwargs.get(
            "spaces_between_special_tokens", False
        )
        return kwargs