vllm_causallms.py 30.8 KB
Newer Older
1
import copy
2
import gc
Lintang Sutawika's avatar
Lintang Sutawika committed
3
import logging
4
import os
Baber Abbasi's avatar
Baber Abbasi committed
5
from importlib.metadata import version
6
from importlib.util import find_spec
7
8
9
from multiprocessing import Process, Queue
from queue import Empty
from time import sleep
10
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
11

12
import jinja2
13
from more_itertools import distribute
Baber Abbasi's avatar
Baber Abbasi committed
14
from packaging.version import parse as parse_version
15
16
from tqdm import tqdm

baberabb's avatar
baberabb committed
17
from lm_eval.api.instance import Instance
18
from lm_eval.api.model import TemplateLM
baberabb's avatar
baberabb committed
19
from lm_eval.api.registry import register_model
20
21
22
23
from lm_eval.models.utils import (
    Collator,
    configure_pad_token,
    handle_stop_sequences,
24
    postprocess_generated_text,
25
26
    undistribute,
)
27
28
29
30
from lm_eval.utils import (
    get_rolling_token_windows,
    make_disjoint_window,
)
31

Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
32

33
try:
34
    import ray
35
    from vllm import LLM, SamplingParams, TokensPrompt
36
    from vllm.lora.request import LoRARequest
baberabb's avatar
baberabb committed
37
    from vllm.transformers_utils.tokenizer import get_tokenizer
38
    from vllm.utils import get_open_port
39
40
41

    if parse_version(version("vllm")) >= parse_version("0.8.3"):
        from vllm.entrypoints.chat_utils import resolve_hf_chat_template
42
43
except ModuleNotFoundError:
    pass
Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
44

45
46
if TYPE_CHECKING:
    # Reserved for annotation-only imports; none are currently needed.
    pass

# Module-level logger shared by the multiprocessing worker and the VLLM class.
eval_logger = logging.getLogger(name=__name__)
baberabb's avatar
baberabb committed
49

baberabb's avatar
baberabb committed
50

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def _vllm_mp_worker(
    model_args: dict,
    sampling_params: "SamplingParams",
    requests: list[list[int]],
    lora_request: "LoRARequest",
    result_queue: "Queue",
    dp_size: int,
    local_dp_rank: int,
    dp_master_port: int,
    dp_master_ip: str = "127.0.0.1",
) -> None:
    """
    Worker process for vLLM multiprocessing.
    Initializes a vLLM engine, processes requests, and puts results or errors
    onto the result_queue.
    """

    if not requests:
        result_queue.put((local_dp_rank, []))
        return None

    os.environ["VLLM_DP_RANK"] = os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = str(dp_master_ip)
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)

    llm = None
    try:
        llm = LLM(**model_args)
        res = llm.generate(
81
            [TokensPrompt(prompt_token_ids=request) for request in requests],
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
            sampling_params=sampling_params,
            lora_request=lora_request,
        )
        # Give engines time to pause their processing loops before exiting."
        sleep(1)
        result_queue.put((local_dp_rank, res))

    except Exception as e:
        error_message = f"Worker {local_dp_rank} failed during generation: {type(e).__name__}: {str(e)}"
        eval_logger.error(error_message, exc_info=True)
        result_queue.put((local_dp_rank, {"error": error_message}))

    finally:
        if llm is not None:
            try:
                del llm
                gc.collect()
            except Exception as e_cleanup:
                eval_logger.warning(
                    f"Worker {local_dp_rank} encountered an error during LLM cleanup: {type(e_cleanup).__name__}: {str(e_cleanup)}",
                    exc_info=True,
                )

    return None


baberabb's avatar
baberabb committed
108
@register_model("vllm")
109
class VLLM(TemplateLM):
baberabb's avatar
baberabb committed
110
111
112
113
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: str,
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer: Optional[str] = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: Optional[str] = None,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
        max_length: Optional[int] = None,
        max_model_len: Optional[int] = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        data_parallel_size: int = 1,
        lora_local_path: Optional[str] = None,
        # VLLM: enable thinking tags in the prompt.
        enable_thinking: bool = True,
        chat_template_args: Optional[dict] = None,
        # End marker for thinking tags - splits to get response after this token (if provided).
        think_end_token: Optional[str] = None,
        max_lora_rank: int = 16,
        **kwargs,
    ):
        """
        Build a vLLM-backed language model.

        With ``data_parallel_size <= 1`` a single ``vllm.LLM`` engine is created
        here and stored as ``self.model``. Otherwise no engine is created in
        this process: the engine arguments are kept in ``self.model_args`` for
        the per-rank workers launched by ``_model_generate``, and batching is
        forced to ``"auto"``.

        Args (selected):
            pretrained: model name or path, passed to vLLM as ``model``.
            max_length, max_model_len: maximum sequence length; at most one
                of the two may be given.
            batch_size: an int, or a string containing ``"auto"``.
            max_batch_size: default for vLLM's ``max_num_seqs`` (a
                ``max_num_seqs`` kwarg takes precedence).
            enable_thinking: default for the chat template's
                ``enable_thinking`` flag; an ``"enable_thinking"`` key in
                ``chat_template_args`` takes precedence.
            chat_template_args: extra keyword arguments for
                ``tokenizer.apply_chat_template``. The caller's dict is copied,
                not modified.
            think_end_token: forwarded to ``postprocess_generated_text``.
            lora_local_path: LoRA adapter applied to every request.
            **kwargs: extra vLLM engine arguments (``device`` is dropped);
                they override the defaults built here.

        Raises:
            ModuleNotFoundError: if ``vllm`` is not installed.
            AssertionError: if both ``max_length`` and ``max_model_len`` are
                given, or a LoRA adapter is requested with vllm <= 0.3.0.
        """
        super().__init__()

        if not find_spec("vllm"):
            raise ModuleNotFoundError(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            )

        assert max_length is None or max_model_len is None, (
            "Either max_length or max_model_len may be provided, but not both"
        )
        # `device` is never forwarded to the vLLM engine.
        kwargs.pop("device", None)
        self.think_end_token = think_end_token
        # The V1 engine is assumed unless VLLM_USE_V1 is explicitly "0".
        self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0"
        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
        self.model_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
            "max_num_seqs": kwargs.get("max_num_seqs", max_batch_size),
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
            "enable_lora": bool(lora_local_path),
            "max_lora_rank": int(max_lora_rank),
        }
        # Remaining user kwargs override the defaults above.
        self.model_args.update(kwargs)
        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
            else int(batch_size)
        )
        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
            self.model_args["distributed_executor_backend"] = (
                "ray"
                if not self.V1
                else self.model_args.get("distributed_executor_backend", None)
            )
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")

        if "gemma" in pretrained.lower():
            add_bos_token = True
            eval_logger.info(
                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
            )

        from transformers import AutoConfig

        self._config = AutoConfig.from_pretrained(
            pretrained, trust_remote_code=trust_remote_code, revision=revision
        )
        self.tokenizer = get_tokenizer(
            tokenizer if tokenizer else pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            revision=tokenizer_revision,
            add_bos_token=add_bos_token,
        )
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
        # Copy before popping so the dict supplied by the caller is not mutated.
        self.chat_template_args = dict(chat_template_args or {})
        self.enable_thinking = self.chat_template_args.pop(
            "enable_thinking", enable_thinking
        )
        self.add_bos_token = add_bos_token

        vllm_version = parse_version(version("vllm"))
        if vllm_version >= parse_version("0.8.3"):
            kwargs_resolve_hf_chat_template = {
                "tokenizer": self.tokenizer,
                "chat_template": None,
                "tools": None,
            }

            if vllm_version >= parse_version("0.9.0"):
                if self.data_parallel_size <= 1:
                    kwargs_resolve_hf_chat_template["model_config"] = (
                        self.model.llm_engine.model_config
                    )
                else:
                    # No engine exists in this process under data parallelism,
                    # so derive a model config from the engine arguments.
                    from vllm.engine.arg_utils import EngineArgs

                    engine_args = EngineArgs(**self.model_args)
                    model_config = engine_args.create_model_config()

                    kwargs_resolve_hf_chat_template["model_config"] = model_config
            else:
                kwargs_resolve_hf_chat_template["trust_remote_code"] = trust_remote_code

            self.hf_chat_template = resolve_hf_chat_template(
                **kwargs_resolve_hf_chat_template
            )
        else:
            self.hf_chat_template = None

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

        self._max_gen_toks = max_gen_toks

        if lora_local_path is not None:
            assert vllm_version > parse_version("0.3.0"), (
                "lora adapters only compatible with vllm > v0.3.0."
            )
            self.lora_request = LoRARequest("finetuned", 1, lora_local_path)
        else:
            self.lora_request = None

baberabb's avatar
baberabb committed
267
268
269
270
271
    @property
    def eot_token_id(self):
        """
        Token id that marks the end of generated text.

        This is the tokenizer's EOS id. The name says "end of text" because that
        describes its role here better than "end of sentence".
        """
        tok = self.tokenizer
        return tok.eos_token_id

272
273
274
275
276
277
278
279
280
    @property
    def prefix_token_id(self):
        """
        Token id prepended as context when scoring loglikelihoods.

        Preference order: the user-supplied ``prefix_token_id``, then the
        tokenizer's BOS id, and finally the tokenizer's EOS id.
        """
        custom = self.custom_prefix_token_id
        if custom is not None:
            return custom
        bos_id = self.tokenizer.bos_token_id
        return bos_id if bos_id is not None else self.tokenizer.eos_token_id

baberabb's avatar
baberabb committed
281
282
283
284
    @property
    def max_length(self):
        """
        Maximum total sequence length (prompt + generation) in tokens.

        Resolution order:
          1. ``max_length`` / ``max_model_len`` given at construction, if truthy;
          2. the local engine's ``max_model_len`` when ``data_parallel_size <= 1``;
          3. the first of ``n_positions``, ``max_position_embeddings``, ``n_ctx``
             present on the HF model config;
          4. the tokenizer's ``model_max_length``, unless it equals the
             transformers "no limit" sentinel ``int(1e30)``;
          5. ``_DEFAULT_MAX_LENGTH``.
        """
        explicit = self._max_length
        if explicit:
            return explicit

        if self.data_parallel_size <= 1:
            return self.model.llm_engine.model_config.max_model_len

        # Under data parallelism no engine lives in this process; use the HF config.
        config = self._config
        for attr_name in ("n_positions", "max_position_embeddings", "n_ctx"):
            if hasattr(config, attr_name):
                return getattr(config, attr_name)

        if not hasattr(self.tokenizer, "model_max_length"):
            return self._DEFAULT_MAX_LENGTH
        tokenizer_limit = self.tokenizer.model_max_length
        if tokenizer_limit == 1000000000000000019884624838656:
            return self._DEFAULT_MAX_LENGTH
        return tokenizer_limit
baberabb's avatar
baberabb committed
297
298
299
300
301

    @property
    def max_gen_toks(self):
        """Default generation budget, used when a request does not set ``max_gen_toks``."""
        budget = self._max_gen_toks
        return budget

Baber Abbasi's avatar
Baber Abbasi committed
302
303
304
    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """
        Render a chat history into a prompt string with the tokenizer's chat template.

        If rendering raises ``jinja2.exceptions.TemplateError`` (presumably a
        template that does not accept a system role), rendering is retried once
        with every ``system`` message removed.
        """

        def _render(messages):
            # Shared by the first attempt and the system-free retry.
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=not add_generation_prompt,
                chat_template=self.hf_chat_template,
                enable_thinking=self.enable_thinking,
                **self.chat_template_args,
            )

        try:
            return _render(chat_history)
        except jinja2.exceptions.TemplateError:
            eval_logger.warning(
                "Failed to apply chat template. removing the system role in chat history."
            )
            without_system = [msg for msg in chat_history if msg["role"] != "system"]
            return _render(without_system)

334
335
336
337
    @property
    def tokenizer_name(self) -> str:
        """Tokenizer name/path with every ``/`` replaced by ``__`` (presumably for use in file/cache names)."""
        return "__".join(self.tokenizer.name_or_path.split("/"))

baberabb's avatar
baberabb committed
338
339
    def tok_encode(
        self,
        string: Union[str, List[str]],
        left_truncate_len: int = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
    ) -> Union[List[int], List[List[int]]]:
        """
        Tokenize one string, or a batch of strings, into token ids.

        Args:
            string: a single ``str`` or a sequence of strings.
            left_truncate_len: if truthy, keep only the last ``left_truncate_len``
                ids of each encoding.
            add_special_tokens: add the tokenizer's special tokens. When falsy,
                ``self.add_bos_token`` decides instead.
            truncation: forwarded to the tokenizer.

        Returns:
            A list of ids for a ``str`` input, otherwise one list of ids per string.
        """
        use_special_tokens = add_special_tokens or self.add_bos_token

        token_ids = self.tokenizer(
            string,
            add_special_tokens=use_special_tokens,
            truncation=truncation,
            return_attention_mask=False,
        ).input_ids

        if not left_truncate_len:
            return token_ids
        if isinstance(string, str):
            return token_ids[-left_truncate_len:]
        return [seq[-left_truncate_len:] for seq in token_ids]

    def _model_generate(
        self,
        requests: List[List[int]] = None,
        generate: bool = False,
        max_tokens: int = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Run vLLM on pre-tokenized prompts.

        Args:
            requests: prompts, each given as a list of token ids.
            generate: if True, sample continuations using ``max_tokens``,
                ``stop`` and the extra sampling ``kwargs``. If False, score the
                prompts (``prompt_logprobs=1``, one greedy token, no
                detokenization).
            max_tokens: maximum number of new tokens when ``generate`` is True.
            stop: stop sequences when ``generate`` is True.
            **kwargs: extra sampling arguments, normalized by ``modify_gen_kwargs``.

        Returns:
            The vLLM outputs for ``requests``. Under data parallelism the
            per-rank results are recombined with ``undistribute``.

        Raises:
            RuntimeError: in the multiprocessing data-parallel path, if a worker
                reports an error or dies unexpectedly.
        """
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            # Scoring mode: one greedy token, plus logprobs for the prompt tokens.
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )
        if self.data_parallel_size > 1 and not self.V1:
            # vLLM hangs if resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            @ray.remote
            def run_inference_one_model(
                model_args: dict,
                sampling_params: SamplingParams,
                requests: List[List[int]],
                lora_request: LoRARequest,
            ):
                llm = LLM(**model_args)
                return llm.generate(
                    [TokensPrompt(prompt_token_ids=request) for request in requests],
                    sampling_params=sampling_params,
                    lora_request=lora_request,
                )

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = (
                (self.model_args, sampling_params, req, self.lora_request)
                for req in requests
            )
            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
            results = ray.get(object_refs)
            # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
            ray.shutdown()
            # flatten results
            return undistribute(results)

        elif self.data_parallel_size > 1:
            # based on https://github.com/vllm-project/vllm/blob/a04720bc36401d831cb048c3917b9e58173d9c1d/examples/offline_inference/data_parallel.py
            requests = list(requests)
            # Launch at most one rank per request. vLLM's data-parallel example
            # gives every rank at least one prompt. A rank with an empty shard
            # leaves `_vllm_mp_worker` without starting its engine, which may
            # leave the other ranks (told VLLM_DP_SIZE=dp_size) waiting for it.
            dp_size = min(self.data_parallel_size, len(requests))
            if dp_size == 0:
                return []
            dp_master_ip = os.environ.get("VLLM_DP_MASTER_IP", "127.0.0.1")
            dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()

            # Interleaved split, as in the ray path, to balance context lengths.
            requests = (list(x) for x in distribute(dp_size, requests))

            procs, resq = [], Queue()
            # We use Process as it is non-daemonic
            try:
                for rank, req in enumerate(requests):
                    proc = Process(
                        target=_vllm_mp_worker,
                        args=(
                            self.model_args.copy(),
                            sampling_params,
                            req,
                            self.lora_request,
                            resq,
                            dp_size,
                            rank,
                            dp_master_port,
                            dp_master_ip,
                        ),
                    )
                    proc.start()
                    procs.append(proc)

                # Collect results. Drain the queue before joining the workers;
                # on each timeout check whether any worker died without reporting.
                rank_res = {}
                while len(rank_res) < len(procs):
                    try:
                        rank, result = resq.get(timeout=30)
                        if isinstance(result, dict) and "error" in result:
                            raise RuntimeError(result["error"])
                        rank_res[rank] = result
                    except Empty:
                        dead_procs = [
                            idx
                            for idx, p in enumerate(procs)
                            if not p.is_alive() and idx not in rank_res
                        ]
                        if dead_procs:
                            raise RuntimeError(
                                f"Worker processes {dead_procs} died unexpectedly"
                            )
                        continue

                results = [rank_res[i] for i in range(len(procs))]
                return undistribute(results)

            # cleanup
            finally:
                try:
                    resq.close()
                    resq.join_thread()
                except Exception:
                    eval_logger.debug(
                        "Failed to close vllm DP results queue", exc_info=True
                    )
                # Escalate join -> terminate -> kill so no worker outlives this call.
                for proc in procs:
                    proc.join(timeout=10)
                    if proc.is_alive():
                        proc.terminate()
                        proc.join(timeout=5)
                        if proc.is_alive():
                            proc.kill()

        else:
            outputs = self.model.generate(
                [TokensPrompt(prompt_token_ids=request) for request in requests],
                sampling_params=sampling_params,
                use_tqdm=self.batch_size == "auto",
                lora_request=self.lora_request,
            )
            return outputs
baberabb's avatar
baberabb committed
486

487
488
489
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """
        Compute the total loglikelihood of each request's full string.

        Each string is tokenized and split into disjoint rolling windows sized
        for ``self.max_length`` (prefixed with ``self.prefix_token_id``). Windows
        from all requests are scored together in batches, and the per-window
        loglikelihoods are summed back per request.

        Args:
            requests: instances whose ``args`` is a 1-tuple ``(string,)``.
            disable_tqdm: disable the tokenization progress bar.

        Returns:
            One summed loglikelihood per request, in input order.
        """
        # Nothing to score. Returning here also avoids int("auto") below: with
        # auto batching an empty request list gives an adaptive batch size of 0.
        if not requests:
            return []

        adaptive_batch_size = None
        if self.batch_size == "auto":
            adaptive_batch_size = len(requests)

        # First, collect all windows from all requests
        all_windows = []  # List of (request_idx, window) tuples
        request_window_counts = []  # Track number of windows per request

        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                map(
                    make_disjoint_window,
                    get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        # max_seq_len - (1 for context)
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [(None,) + x for x in rolling_token_windows]

            # Store windows with their request index
            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))

        all_nlls = []
        batch_size = adaptive_batch_size or int(self.batch_size)
        for i in range(0, len(all_windows), batch_size):
            batch = all_windows[i : i + batch_size]
            # Extract just the windows for processing, keeping track of request indices
            batch_indices, batch_windows = zip(*batch)

            # NOTE(review): the per-batch progress bar ignores `disable_tqdm`.
            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                disable_tqdm=False,
            )
            # Store results with their request indices
            all_nlls.extend(zip(batch_indices, batch_nlls))

        # Reconstruct per-request loglikelihoods. Windows were appended and
        # scored in request order, so each request owns a contiguous slice.
        loglikelihoods = []
        current_idx = 0
        for request, window_count in zip(requests, request_window_counts):
            request_nlls = all_nlls[current_idx : current_idx + window_count]
            # Sum up the nlls for this request (discarding is_greedy)
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            current_idx += window_count

            string = request.args[0]
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (string,), request_total
            )

        return loglikelihoods

556
557
558
    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """
        Generate a continuation for each request until a stop sequence is hit.

        Each request's ``args`` is ``(context, gen_kwargs)``. Requests are
        grouped by ``gen_kwargs`` and sorted longest-context first. ``until``
        plus the EOS string form the stop list, and ``max_gen_toks`` (default
        ``self.max_gen_toks``) sets the generation budget. Contexts are
        left-truncated to leave room for that budget within ``self.max_length``.

        Returns:
            Generated strings in the original request order.

        Raises:
            ValueError: if a request's ``gen_kwargs`` is not a dict.
        """
        # `zip(*...)` below cannot unpack zero requests, so handle that first.
        if not requests:
            return []

        res = []

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding: List[List[int]] = self.tok_encode(
            context, add_special_tokens=self.add_bos_token
        )
        requests = [
            ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), _requests[0][0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
        chunks = re_ords.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        # for each different set of kwargs, we execute all requests, by batch.
        eos = self.tokenizer.decode(self.eot_token_id)
        for chunk in chunks:
            context_and_encoding, all_gen_kwargs = zip(*chunk)
            context, context_encoding = zip(*context_and_encoding)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                # add EOS token to stop sequences
                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

            # set the max length in tokens of inputs ("context_enc")
            # max len for inputs = max length, minus room to generate the max new tokens
            # NOTE(review): if max_gen_toks >= self.max_length then max_ctx_len <= 0,
            # and `x[-max_ctx_len:]` below does not truncate as intended (`x[-0:]`
            # keeps everything) - confirm whether this should raise instead.
            max_ctx_len = self.max_length - max_gen_toks
            all_lengths = [len(x) for x in context_encoding]
            for length in all_lengths:
                if length > max_ctx_len:
                    eval_logger.warning(
                        f"Context length {length} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
                    )
            context_encoding = [x[-max_ctx_len:] for x in context_encoding]

            # perform batched generation
            cont = self._model_generate(
                requests=context_encoding,
                generate=True,
                max_tokens=max_gen_toks,
                stop=until,
                **kwargs,
            )

            # cache generations
            for output, ctx in zip(cont, context):
                generated_text: str = output.outputs[0].text
                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                generated_text = postprocess_generated_text(
                    generated_text, until, self.think_end_token
                )
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (ctx, gen_kwargs), generated_text
                )
                pbar.update(1)

        pbar.close()
        # reorder all group of results back to original unsorted form
        return re_ords.get_original(res)
baberabb's avatar
baberabb committed
650
651

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        """Score tokenized (context, continuation) pairs via vLLM prompt logprobs.

        :param requests: list of ``(cache_key, context_enc, continuation_enc)``.
            ``cache_key`` may be ``None`` (loglikelihood_rolling windows); such
            results are not added to the cache here.
        :param disable_tqdm: disable the progress bar when True.
        :return: one ``(continuation_logprob, is_greedy)`` tuple per request,
            restored to the original request order.
        """
        res = []

        def _collate(x):
            # Sort key: negated total length (longest first, assuming the
            # Collator sorts ascending); the token tuple breaks ties deterministically.
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # Reorder requests by length and batch. For batch_size == "auto" the
        # Collator gets n=0 (presumably "no fixed batch size" -- TODO confirm
        # against Collator.get_batched).
        re_ord = Collator(requests, sort_fn=_collate)
        chunks = re_ord.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=disable_tqdm,
            desc="Running loglikelihood requests",
        )
        max_length = self.max_length
        for chunk in chunks:
            inputs = []
            ctxlens = []
            for _, context_enc, continuation_enc in chunk:
                full_enc = context_enc + continuation_enc
                if (full_length := len(full_enc)) > max_length:
                    eval_logger.warning(
                        f"Context length {full_length} exceeds max length ({max_length}). Truncating context."
                    )
                # Left-truncate: keep only the right-most max_length tokens.
                inp = full_enc[-max_length:]
                # Context tokens that survive the left-truncation. Clamped to >= 1
                # because the first prompt position never has a logprob: when the
                # continuation alone fills the window the raw value is <= 0, which
                # would either hit that None entry (0) or, via negative slicing,
                # silently score only the tail of the sequence (< 0).
                ctxlen = max(1, len(context_enc) - max(0, full_length - max_length))

                inputs.append(inp)
                ctxlens.append(ctxlen)

            outputs = self._model_generate(requests=inputs, generate=False)

            for output, ctxlen, (cache_key, _, _), inp in zip(
                outputs, ctxlens, chunk, inputs
            ):
                answer = self._parse_logprobs(
                    tokens=inp,
                    outputs=output,
                    ctxlen=ctxlen,
                )

                res.append(answer)

                if cache_key is not None:
                    # special case: loglikelihood_rolling produces a number of loglikelihood requests
                    # all with cache key None. instead do add_partial on the per-example level
                    # in the loglikelihood_rolling() function for those.
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)
        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Input tokens (potentially left-truncated)
        :param outputs: RequestOutput
            Contains prompt_logprobs: one entry per prompt position, each a dict
            mapping token id -> logprob (float, or a vLLM Logprob object), except
            the first entry, which is None.
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the
            predictions). Values below 1 are treated as 1, since position 0
            never carries a logprob.
        :return:
            continuation_logprobs: float
                Sum of log probabilities of the continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

        def coerce_logprob_to_num(logprob):
            # vLLM changed the return type of logprobs from float
            # to a Logprob object storing the float value + extra data
            # (https://github.com/vllm-project/vllm/pull/3065).
            # If we are dealing with vllm's Logprob object, return
            # the logprob value stored as an attribute. Otherwise,
            # return the object itself (which should be a float
            # for older versions of vLLM).
            return getattr(logprob, "logprob", logprob)

        # The first entry of prompt_logprobs is None because the model has no
        # previous tokens to condition on, so scoring never starts before index 1.
        # A non-positive ctxlen would otherwise read that None entry (0) or wrap
        # around and score only the tail of the sequence (< 0).
        start = max(ctxlen, 1)

        continuation_logprobs = 0
        is_greedy = True
        for token, logprob_dict in zip(
            tokens[start:], outputs.prompt_logprobs[start:]
        ):
            # Only the scored positions are coerced; context positions are skipped.
            logprobs = {
                tok: coerce_logprob_to_num(lp) for tok, lp in logprob_dict.items()
            }
            continuation_logprobs += logprobs[token]
            # Greedy iff the highest-logprob token at every position is the given one.
            if is_greedy and max(logprobs, key=logprobs.get) != token:
                is_greedy = False

        return continuation_logprobs, is_greedy
775
776
777
778

    @staticmethod
    def modify_gen_kwargs(kwargs: dict) -> dict:
        """Normalize generation kwargs before building vLLM sampling params.

        The dict is modified in place and also returned:

        * ``do_sample`` is removed.
        * ``temperature`` defaults to ``0.0`` when absent; a debug message is
          logged when that happens alongside an explicit ``do_sample=False``.
        * ``skip_special_tokens`` and ``spaces_between_special_tokens`` default
          to ``False`` (the "hf defaults").
        """
        # sampling_params
        # Pop `do_sample` before defaulting `temperature`; otherwise the
        # `do_sample=False` check could never observe a missing temperature.
        do_sample = kwargs.pop("do_sample", None)
        if "temperature" not in kwargs:
            if do_sample is False:
                eval_logger.debug(
                    "Got `do_sample=False` and no temperature value, setting VLLM temperature to 0.0 ..."
                )
            # NOTE: a missing temperature maps to 0.0 regardless of do_sample;
            # kept as-is for backward compatibility.
            kwargs["temperature"] = 0.0
        # hf defaults
        kwargs.setdefault("skip_special_tokens", False)
        kwargs.setdefault("spaces_between_special_tokens", False)
        return kwargs