vllm_causallms.py 30.5 KB
Newer Older
1
import copy
2
import gc
3
import inspect
Lintang Sutawika's avatar
Lintang Sutawika committed
4
import logging
5
import os
Baber Abbasi's avatar
Baber Abbasi committed
6
from importlib.metadata import version
7
from importlib.util import find_spec
8
9
10
from multiprocessing import Process, Queue
from queue import Empty
from time import sleep
11
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
12

13
import jinja2
14
from more_itertools import distribute
Baber Abbasi's avatar
Baber Abbasi committed
15
from packaging.version import parse as parse_version
16
17
from tqdm import tqdm

baberabb's avatar
baberabb committed
18
from lm_eval.api.instance import Instance
19
from lm_eval.api.model import TemplateLM
baberabb's avatar
baberabb committed
20
from lm_eval.api.registry import register_model
21
22
23
24
25
26
from lm_eval.models.utils import (
    Collator,
    configure_pad_token,
    handle_stop_sequences,
    undistribute,
)
27
28
29
30
from lm_eval.utils import (
    get_rolling_token_windows,
    make_disjoint_window,
)
31

Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
32

33
try:
34
    import ray
35
    from vllm import LLM, SamplingParams
36
    from vllm.lora.request import LoRARequest
baberabb's avatar
baberabb committed
37
    from vllm.transformers_utils.tokenizer import get_tokenizer
38
    from vllm.utils import get_open_port
39
40
41

    if parse_version(version("vllm")) >= parse_version("0.8.3"):
        from vllm.entrypoints.chat_utils import resolve_hf_chat_template
42
43
except ModuleNotFoundError:
    pass
Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
44

45
46
if TYPE_CHECKING:
    pass
bcicc's avatar
bcicc committed
47

Lintang Sutawika's avatar
Lintang Sutawika committed
48
eval_logger = logging.getLogger(__name__)
baberabb's avatar
baberabb committed
49

baberabb's avatar
baberabb committed
50

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
def _vllm_mp_worker(
    model_args: dict,
    sampling_params: "SamplingParams",
    requests: list[list[int]],
    lora_request: "LoRARequest",
    result_queue: "Queue",
    dp_size: int,
    local_dp_rank: int,
    dp_master_port: int,
    dp_master_ip: str = "127.0.0.1",
) -> None:
    """
    Worker process for vLLM multiprocessing.
    Initializes a vLLM engine, processes requests, and puts results or errors
    onto the result_queue.
    """

    if not requests:
        result_queue.put((local_dp_rank, []))
        return None

    os.environ["VLLM_DP_RANK"] = os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = str(dp_master_ip)
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)

    llm = None
    try:
        llm = LLM(**model_args)
        res = llm.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            lora_request=lora_request,
        )
        # Give engines time to pause their processing loops before exiting."
        sleep(1)
        result_queue.put((local_dp_rank, res))

    except Exception as e:
        error_message = f"Worker {local_dp_rank} failed during generation: {type(e).__name__}: {str(e)}"
        eval_logger.error(error_message, exc_info=True)
        result_queue.put((local_dp_rank, {"error": error_message}))

    finally:
        if llm is not None:
            try:
                del llm
                gc.collect()
            except Exception as e_cleanup:
                eval_logger.warning(
                    f"Worker {local_dp_rank} encountered an error during LLM cleanup: {type(e_cleanup).__name__}: {str(e_cleanup)}",
                    exc_info=True,
                )

    return None


baberabb's avatar
baberabb committed
108
@register_model("vllm")
109
class VLLM(TemplateLM):
baberabb's avatar
baberabb committed
110
111
112
113
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: str,
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer: Optional[str] = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: Optional[str] = None,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
        max_length: int = None,
        max_model_len: int = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: str = "cuda",
        data_parallel_size: int = 1,
        lora_local_path: str = None,
        enable_thinking: bool = False,
        max_lora_rank: int = 16,
        **kwargs,
    ):
        """
        Set up a vLLM-backed model wrapper.

        Collects the engine arguments in ``self.model_args``, creates the engine
        right away when ``data_parallel_size <= 1`` (with data parallelism the
        engines are created later, per generation call), loads the HF config and
        tokenizer, resolves the chat template on vLLM >= 0.8.3 and prepares the
        optional LoRA request.

        Args:
            pretrained: Model identifier; passed to vLLM as ``model``, to
                ``AutoConfig.from_pretrained`` and, when ``tokenizer`` is not
                given, to ``get_tokenizer``.
            dtype, revision, trust_remote_code, tokenizer, tokenizer_mode,
            tokenizer_revision, quantization, swap_space, seed,
            gpu_memory_utilization, device, tensor_parallel_size,
            max_lora_rank: Forwarded into the vLLM engine arguments.
            add_bos_token: Passed to ``get_tokenizer`` and stored; forced to
                ``True`` when ``"gemma"`` occurs in ``pretrained``.
            prefix_token_id: Optional override stored as
                ``custom_prefix_token_id``.
            max_gen_toks: Stored as the default generation budget.
            batch_size: ``"auto"`` (any string containing it) or an integer.
            max_batch_size: Default for vLLM ``max_num_seqs`` unless that key is
                present in ``kwargs``.
            max_length, max_model_len: Mutually exclusive ways to set the
                maximum sequence length.
            data_parallel_size: Number of data-parallel replicas.
            lora_local_path: Path of a LoRA adapter; enables LoRA when truthy.
            enable_thinking: Stored and later passed to the chat template.
            **kwargs: Extra engine arguments; they override the defaults above.

        Raises:
            ModuleNotFoundError: If the ``vllm`` package cannot be found.
            AssertionError: If both ``max_length`` and ``max_model_len`` are
                given, or a LoRA path is used with vLLM <= 0.3.0.
        """
        super().__init__()

        if not find_spec("vllm"):
            raise ModuleNotFoundError(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            )

        assert max_length is None or max_model_len is None, (
            "Either max_length or max_model_len may be provided, but not both"
        )
        # Any value other than "0" (including unset) selects the V1 code path.
        self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0"
        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)

        engine_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": self.tensor_parallel_size,
            "max_model_len": int(self._max_length) if self._max_length else None,
            "max_num_seqs": kwargs.get("max_num_seqs", max_batch_size),
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
            "device": str(device),
            "enable_lora": bool(lora_local_path),
            "max_lora_rank": int(max_lora_rank),
        }
        # Caller-supplied extras take precedence over the defaults above.
        engine_args.update(kwargs)
        self.model_args = engine_args

        if isinstance(batch_size, str) and "auto" in batch_size:
            self.batch_size = "auto"
        else:
            self.batch_size = int(batch_size)

        if self.data_parallel_size > 1:
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
            # The non-V1 path always runs on ray; V1 keeps whatever was requested.
            if not self.V1:
                executor_backend = "ray"
            else:
                executor_backend = self.model_args.get(
                    "distributed_executor_backend", None
                )
            self.model_args["distributed_executor_backend"] = executor_backend
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")
        else:
            self.model = LLM(**self.model_args)

        from transformers import AutoConfig

        self._config = AutoConfig.from_pretrained(
            pretrained, trust_remote_code=trust_remote_code, revision=revision
        )
        self.tokenizer = get_tokenizer(
            tokenizer or pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            revision=tokenizer_revision,
            add_bos_token=add_bos_token,
        )
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
        self.enable_thinking = enable_thinking
        self.add_bos_token = add_bos_token
        if "gemma" in pretrained.lower():
            self.add_bos_token = True
            eval_logger.info(
                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
            )

        vllm_version = parse_version(version("vllm"))
        if vllm_version >= parse_version("0.8.3"):
            template_kwargs = {
                "tokenizer": self.tokenizer,
                "chat_template": None,
                "tools": None,
            }

            if vllm_version >= parse_version("0.9.0"):
                if self.data_parallel_size <= 1:
                    model_config = self.model.llm_engine.model_config
                else:
                    # No engine exists in this process under data parallelism,
                    # so build the model config from the engine arguments.
                    from vllm.engine.arg_utils import EngineArgs

                    model_config = EngineArgs(**self.model_args).create_model_config()
                template_kwargs["model_config"] = model_config

            # https://github.com/vllm-project/vllm/pull/18259
            # Some vLLM versions spell the parameter "trsut_remote_code"; use
            # whichever name the installed signature accepts.
            accepted_params = inspect.signature(resolve_hf_chat_template).parameters
            if "trsut_remote_code" in accepted_params:
                template_kwargs["trsut_remote_code"] = trust_remote_code
            else:
                template_kwargs["trust_remote_code"] = trust_remote_code

            self.hf_chat_template = resolve_hf_chat_template(**template_kwargs)
        else:
            self.hf_chat_template = None

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

        self._max_gen_toks = max_gen_toks

        if lora_local_path is None:
            self.lora_request = None
        else:
            assert vllm_version > parse_version("0.3.0"), (
                "lora adapters only compatible with vllm > v0.3.0."
            )
            self.lora_request = LoRARequest("finetuned", 1, lora_local_path)

baberabb's avatar
baberabb committed
266
267
268
269
270
    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

271
272
273
274
275
276
277
278
279
    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

baberabb's avatar
baberabb committed
280
281
282
283
    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
284
285
286
287
288
289
290
291
292
293
294
295
        if self.data_parallel_size <= 1:
            return self.model.llm_engine.model_config.max_model_len
        else:
            seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
            for attr in seqlen_config_attrs:
                if hasattr(self._config, attr):
                    return getattr(self._config, attr)
            if hasattr(self.tokenizer, "model_max_length"):
                if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                    return self._DEFAULT_MAX_LENGTH
                return self.tokenizer.model_max_length
            return self._DEFAULT_MAX_LENGTH
baberabb's avatar
baberabb committed
296
297
298
299
300

    @property
    def max_gen_toks(self):
        return self._max_gen_toks

Baber Abbasi's avatar
Baber Abbasi committed
301
302
303
    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
304
305
306
        """
        Method to apply a chat template to a list of chat history between user and model.
        """
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
        try:
            chat_templated = self.tokenizer.apply_chat_template(
                chat_history,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=not add_generation_prompt,
                chat_template=self.hf_chat_template,
                enable_thinking=self.enable_thinking,
            )
        except jinja2.exceptions.TemplateError:
            eval_logger.warning(
                "Failed to apply chat template. removing the system role in chat history."
            )
            chat_templated = self.tokenizer.apply_chat_template(
                [msg for msg in chat_history if msg["role"] != "system"],
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=not add_generation_prompt,
                chat_template=self.hf_chat_template,
                enable_thinking=self.enable_thinking,
            )
328

Baber Abbasi's avatar
Baber Abbasi committed
329
330
        return chat_templated

331
332
333
334
    @property
    def tokenizer_name(self) -> str:
        return self.tokenizer.name_or_path.replace("/", "__")

baberabb's avatar
baberabb committed
335
336
    def tok_encode(
        self,
337
338
339
340
341
        string: Union[str, List[str]],
        left_truncate_len: int = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
    ) -> Union[List[int], List[List[int]]]:
342
343
        if not add_special_tokens:
            add_special_tokens = False or self.add_bos_token
344
345
346
347
348
349
        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
            string,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            return_attention_mask=False,
        ).input_ids
baberabb's avatar
baberabb committed
350
351
352

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
353
354
355
356
            if not isinstance(string, str):
                encoding = [enc[-left_truncate_len:] for enc in encoding]
            else:
                encoding = encoding[-left_truncate_len:]
baberabb's avatar
baberabb committed
357
358
359
360
361

        return encoding

    def _model_generate(
        self,
        requests: List[List[int]] = None,
        generate: bool = False,
        max_tokens: int = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Run the token-id prompts in ``requests`` through vLLM.

        Three execution paths exist:
          * ``data_parallel_size <= 1``: call the engine held in ``self.model``.
          * data parallel without V1: one ray task per shard, each creating its
            own ``LLM``.
          * data parallel with V1: one ``Process`` per shard running
            ``_vllm_mp_worker``; results come back over a ``Queue``.

        Args:
            requests: Token-id prompts.
            generate: ``True`` for free generation (uses ``max_tokens``, ``stop``
                and ``kwargs``); ``False`` for a fixed scoring configuration
                that requests prompt logprobs.
            max_tokens: Generation budget, only used when ``generate`` is true.
            stop: Stop sequences, only used when ``generate`` is true.
            **kwargs: Extra sampling options, passed through
                ``self.modify_gen_kwargs`` when ``generate`` is true.

        Returns:
            The engine output for the single-engine path, otherwise the
            per-shard outputs recombined by ``undistribute``.

        Raises:
            RuntimeError: On the multiprocessing path, when a worker reports an
                error or exits without delivering a result.
        """
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )

        # Single engine: generate directly on the engine created in __init__.
        if self.data_parallel_size <= 1:
            return self.model.generate(
                prompt_token_ids=requests,
                sampling_params=sampling_params,
                use_tqdm=self.batch_size == "auto",
                lora_request=self.lora_request,
            )

        if not self.V1:
            # vLLM hangs if resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            @ray.remote
            def run_inference_one_model(
                model_args: dict,
                sampling_params: SamplingParams,
                requests: List[List[int]],
                lora_request: LoRARequest,
            ):
                llm = LLM(**model_args)
                return llm.generate(
                    prompt_token_ids=requests,
                    sampling_params=sampling_params,
                    lora_request=lora_request,
                )

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            shards = [list(x) for x in distribute(self.data_parallel_size, requests)]
            object_refs = [
                run_inference_one_model.remote(
                    self.model_args, sampling_params, shard, self.lora_request
                )
                for shard in shards
            ]
            results = ray.get(object_refs)
            # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
            ray.shutdown()
            # flatten results
            return undistribute(results)

        # based on https://github.com/vllm-project/vllm/blob/a04720bc36401d831cb048c3917b9e58173d9c1d/examples/offline_inference/data_parallel.py
        dp_size = self.data_parallel_size
        dp_master_ip = os.environ.get("VLLM_DP_MASTER_IP", "127.0.0.1")
        dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()

        shards = (list(x) for x in distribute(self.data_parallel_size, requests))

        procs = []
        resq = Queue()
        # We use Process as it is non-daemonic
        try:
            for worker_rank, shard in enumerate(shards):
                worker = Process(
                    target=_vllm_mp_worker,
                    args=(
                        self.model_args.copy(),
                        sampling_params,
                        shard,
                        self.lora_request,
                        resq,
                        dp_size,
                        worker_rank,
                        dp_master_port,
                        dp_master_ip,
                    ),
                )
                worker.start()
                procs.append(worker)

            # Collect results: one entry per worker, keyed by its rank.
            rank_res = {}
            while len(rank_res) < len(procs):
                try:
                    done_rank, payload = resq.get(timeout=30)
                except Empty:
                    # Nothing arrived within the timeout: fail if a worker that
                    # still owes a result is no longer running.
                    dead_procs = [
                        idx
                        for idx, p in enumerate(procs)
                        if not p.is_alive() and idx not in rank_res
                    ]
                    if dead_procs:
                        raise RuntimeError(
                            f"Worker processes {dead_procs} died unexpectedly"
                        )
                else:
                    # A worker signals failure with an {"error": ...} payload.
                    if isinstance(payload, dict) and "error" in payload:
                        raise RuntimeError(payload["error"])
                    rank_res[done_rank] = payload

            ordered = [rank_res[i] for i in range(len(procs))]
            return undistribute(ordered)

        # cleanup
        finally:
            try:
                resq.close()
                resq.join_thread()
            except Exception:
                eval_logger.debug(
                    "Failed to close vllm DP results queue", exc_info=True
                )
            # Escalate join -> terminate -> kill for workers that do not exit.
            for worker in procs:
                worker.join(timeout=10)
                if worker.is_alive():
                    worker.terminate()
                    worker.join(timeout=5)
                    if worker.is_alive():
                        worker.kill()
baberabb's avatar
baberabb committed
483

484
485
486
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
        adaptive_batch_size = None
        if self.batch_size == "auto":
            adaptive_batch_size = len(requests)

        # First, collect all windows from all requests
        all_windows = []  # List of (request_idx, window) tuples
        request_window_counts = []  # Track number of windows per request

        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
baberabb's avatar
baberabb committed
502
                map(
503
504
                    make_disjoint_window,
                    get_rolling_token_windows(
baberabb's avatar
baberabb committed
505
                        token_list=self.tok_encode(string),
506
507
                        prefix_token=self.prefix_token_id,
                        # max_seq_len - (1 for context)
baberabb's avatar
baberabb committed
508
                        max_seq_len=self.max_length - 1,
baberabb's avatar
baberabb committed
509
510
511
512
513
                        context_len=1,
                    ),
                )
            )

514
515
            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [(None,) + x for x in rolling_token_windows]
baberabb's avatar
baberabb committed
516

517
518
519
            # Store windows with their request index
            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))
baberabb's avatar
baberabb committed
520

521
522
523
524
525
526
        all_nlls = []
        batch_size = adaptive_batch_size or int(self.batch_size)
        for i in range(0, len(all_windows), batch_size):
            batch = all_windows[i : i + batch_size]
            # Extract just the windows for processing, keeping track of request indices
            batch_indices, batch_windows = zip(*batch)
baberabb's avatar
baberabb committed
527

528
529
530
531
532
533
            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                disable_tqdm=False,
            )
            # Store results with their request indices
            all_nlls.extend(zip(batch_indices, batch_nlls))
534

535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
        # Reconstruct per-request loglikelihoods
        loglikelihoods = []
        current_idx = 0
        for window_count in request_window_counts:
            # Get all nlls for this request
            request_nlls = all_nlls[current_idx : current_idx + window_count]
            # Sum up the nlls for this request (discarding is_greedy)
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            current_idx += window_count

            string = requests[len(loglikelihoods) - 1].args[0]
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (string,), request_total
            )
550

baberabb's avatar
baberabb committed
551
552
        return loglikelihoods

553
554
555
    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """
        Generate a continuation for every request.

        Each request's ``args`` is a ``(context, gen_kwargs)`` pair. Contexts
        are tokenized in one batch, grouped by generation kwargs and ordered by
        descending token length through ``Collator``, truncated from the left
        to leave room for generation, and sent to ``_model_generate``. The
        generated text is cut at the first occurrence of any stop sequence.

        Args:
            requests: Instances carrying ``(context, gen_kwargs)`` in ``args``.
            disable_tqdm: Disables the progress bar (it is also disabled when
                ``self.rank != 0``).

        Returns:
            The generated strings as reordered by ``Collator.get_original``.

        Raises:
            ValueError: If a request's generation kwargs are not a ``dict``.
        """
        generations = []

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding: List[List[int]] = self.tok_encode(
            context, add_special_tokens=self.add_bos_token
        )
        requests = [
            ((text, tokens), gen_args)
            for text, tokens, gen_args in zip(context, context_encoding, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            text_and_tokens = _requests[0]
            return -len(text_and_tokens[1]), text_and_tokens[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
        chunks = re_ords.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        # for each different set of kwargs, we execute all requests, by batch.
        eos = self.tokenizer.decode(self.eot_token_id)
        for chunk in chunks:
            context_and_encoding, all_gen_kwargs = zip(*chunk)
            context, context_encoding = zip(*context_and_encoding)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )

            # unpack our keyword arguments.
            kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
            # add EOS token to stop sequences
            until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            max_gen_toks = (
                kwargs.pop("max_gen_toks")
                if "max_gen_toks" in kwargs
                else self.max_gen_toks
            )

            # set the max length in tokens of inputs ("context_enc")
            # max len for inputs = max length, minus room to generate the max new tokens
            max_ctx_len = self.max_length - max_gen_toks
            for length in map(len, context_encoding):
                if length > max_ctx_len:
                    eval_logger.warning(
                        f"Context length {length} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
                    )
            # Keep only the rightmost max_ctx_len tokens of every context.
            context_encoding = [tokens[-max_ctx_len:] for tokens in context_encoding]

            # perform batched generation
            cont = self._model_generate(
                requests=context_encoding,
                generate=True,
                max_tokens=max_gen_toks,
                stop=until,
                **kwargs,
            )

            # cache generations
            for output, ctx_text in zip(cont, context):
                generated_text = output.outputs[0].text
                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if not term:
                        continue
                    generated_text = generated_text.split(term)[0]
                generations.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (ctx_text, gen_kwargs), generated_text
                )
                pbar.update(1)

        pbar.close()
        # reorder all group of results back to original unsorted form
        return re_ords.get_original(generations)
baberabb's avatar
baberabb committed
647
648

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        """Score tokenised (context, continuation) pairs with the model.

        :param requests: list
            Triples of ``(cache_key, context_enc, continuation_enc)``. A
            ``cache_key`` of ``None`` means the result is not written to the
            cache hook here.
        :param disable_tqdm: bool
            Forwarded to ``tqdm(disable=...)``.
        :return:
            One ``_parse_logprobs`` result per request, mapped back through
            ``Collator.get_original``.
        """
        scored = []

        def _sort_key(req):
            # Key handed to Collator: negated combined length, then the tokens.
            joined = req[1] + req[2]
            return -len(joined), tuple(joined)

        # Reorder requests by length and batch
        collator = Collator(requests, sort_fn=_sort_key)
        batch_n = 0 if self.batch_size == "auto" else int(self.batch_size)
        batches = collator.get_batched(n=batch_n, batch_fn=None)

        progress = tqdm(
            total=len(requests),
            disable=disable_tqdm,
            desc="Running loglikelihood requests",
        )

        for batch in batches:
            batch_inputs = []
            batch_ctxlens = []
            for _key, context_enc, continuation_enc in batch:
                combined = context_enc + continuation_enc
                full_length = len(combined)
                if full_length > self.max_length:
                    eval_logger.warning(
                        f"Context length {full_length} exceeds max length ({self.max_length}). Truncating context."
                    )

                # Keep only the last `max_length` tokens (left truncation) ...
                batch_inputs.append(combined[-(self.max_length) :])
                # ... and shrink the context length by however many tokens
                # were dropped from the front.
                dropped = max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
                )
                batch_ctxlens.append(len(context_enc) - dropped)

            generations = self._model_generate(requests=batch_inputs, generate=False)

            for generation, n_ctx, (key, _ctx, _cont), toks in zip(
                generations, batch_ctxlens, batch, batch_inputs
            ):
                answer = self._parse_logprobs(
                    tokens=toks,
                    outputs=generation,
                    ctxlen=n_ctx,
                )
                scored.append(answer)

                # special case: loglikelihood_rolling produces a number of loglikelihood requests
                # all with cache key None. instead do add_partial on the per-example level
                # in the loglikelihood_rolling() function for those.
                if key is not None:
                    self.cache_hook.add_partial("loglikelihood", key, answer)
                progress.update(1)

        progress.close()
        return collator.get_original(scored)

    @staticmethod
baberabb's avatar
baberabb committed
711
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
baberabb's avatar
baberabb committed
712
713
714
        """Process logprobs and tokens.

        :param tokens: list
715
            Input tokens (potentially left-truncated)
baberabb's avatar
bugfix  
baberabb committed
716
        :param outputs: RequestOutput
717
            Contains prompt_logprobs
baberabb's avatar
baberabb committed
718
719
720
721
722
723
724
725
726
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

727
        # The first entry of prompt_logprobs is None because the model has no previous tokens to condition on.
baberabb's avatar
bugfix  
baberabb committed
728
729
        continuation_logprobs_dicts = outputs.prompt_logprobs

730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
        def coerce_logprob_to_num(logprob):
            # vLLM changed the return type of logprobs from float
            # to a Logprob object storing the float value + extra data
            # (https://github.com/vllm-project/vllm/pull/3065).
            # If we are dealing with vllm's Logprob object, return
            # the logprob value stored as an attribute. Otherwise,
            # return the object itself (which should be a float
            # for older versions of vLLM).
            return getattr(logprob, "logprob", logprob)

        continuation_logprobs_dicts = [
            {
                token: coerce_logprob_to_num(logprob)
                for token, logprob in logprob_dict.items()
            }
            if logprob_dict is not None
            else None
            for logprob_dict in continuation_logprobs_dicts
        ]

baberabb's avatar
baberabb committed
750
        # Calculate continuation_logprobs
751
        # assume ctxlen always >= 1
baberabb's avatar
baberabb committed
752
        continuation_logprobs = sum(
baberabb's avatar
baberabb committed
753
            logprob_dict.get(token)
baberabb's avatar
baberabb committed
754
            for token, logprob_dict in zip(
baberabb's avatar
bugfix  
baberabb committed
755
                tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
baberabb's avatar
baberabb committed
756
757
758
759
760
            )
        )

        # Determine if is_greedy
        is_greedy = True
baberabb's avatar
baberabb committed
761
762
763
        for token, logprob_dict in zip(
            tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
        ):
baberabb's avatar
bugfix  
baberabb committed
764
765
766
767
768
769
            # Get the token with the maximum log probability from the logprob_dict
            if logprob_dict:  # Ensure the logprob_dict is not None
                top_token = max(logprob_dict, key=logprob_dict.get)
                if top_token != token:
                    is_greedy = False
                    break
baberabb's avatar
baberabb committed
770
771

        return continuation_logprobs, is_greedy
772
773
774
775

    @staticmethod
    def modify_gen_kwargs(kwargs: dict) -> dict:
        # sampling_params
776
        kwargs["temperature"] = kwargs.get("temperature", 0.0)
777
        do_sample = kwargs.pop("do_sample", None)
778
779
780
781
        if do_sample is False and "temperature" not in kwargs:
            eval_logger.debug(
                "Got `do_sample=False` and no temperature value, setting VLLM temperature to 0.0 ..."
            )
782
783
784
785
786
787
788
            kwargs["temperature"] = 0.0
        # hf defaults
        kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
        kwargs["spaces_between_special_tokens"] = kwargs.get(
            "spaces_between_special_tokens", False
        )
        return kwargs