"demo/ocr_demo.bak" did not exist on "f36c26565ec06651274178889f87599604e39cce"
vllm_causallms.py 31.2 KB
Newer Older
Baber's avatar
Baber committed
1
2
from __future__ import annotations

3
import copy
4
import gc
Lintang Sutawika's avatar
Lintang Sutawika committed
5
import logging
6
import os
Baber's avatar
Baber committed
7
from collections.abc import Sequence
Baber Abbasi's avatar
Baber Abbasi committed
8
from importlib.metadata import version
9
from importlib.util import find_spec
10
11
12
from multiprocessing import Process, Queue
from queue import Empty
from time import sleep
Baber's avatar
Baber committed
13
from typing import TYPE_CHECKING, Literal
14

15
import jinja2
16
from more_itertools import distribute
Baber Abbasi's avatar
Baber Abbasi committed
17
from packaging.version import parse as parse_version
18
19
from tqdm import tqdm

baberabb's avatar
baberabb committed
20
from lm_eval.api.instance import Instance
21
from lm_eval.api.model import TemplateLM
baberabb's avatar
baberabb committed
22
from lm_eval.api.registry import register_model
23
24
from lm_eval.models.utils import (
    Collator,
Baber's avatar
Baber committed
25
    bos_already_added,
26
27
    configure_pad_token,
    handle_stop_sequences,
28
    postprocess_generated_text,
29
30
    undistribute,
)
31
32
33
34
from lm_eval.utils import (
    get_rolling_token_windows,
    make_disjoint_window,
)
35

Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
36

37
try:
38
    import ray
39
    from vllm import LLM, SamplingParams, TokensPrompt
40
    from vllm.lora.request import LoRARequest
baberabb's avatar
baberabb committed
41
    from vllm.transformers_utils.tokenizer import get_tokenizer
42
    from vllm.utils import get_open_port
43
44
45

    if parse_version(version("vllm")) >= parse_version("0.8.3"):
        from vllm.entrypoints.chat_utils import resolve_hf_chat_template
46
47
except ModuleNotFoundError:
    pass
Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
48

49
50
# Module-level logger shared by the vLLM backend and its worker processes.
eval_logger = logging.getLogger(__name__)
baberabb's avatar
baberabb committed
53

baberabb's avatar
baberabb committed
54

55
56
def _vllm_mp_worker(
    model_args: dict,
Baber's avatar
Baber committed
57
    sampling_params: list[SamplingParams],
58
    requests: list[list[int]],
Baber's avatar
Baber committed
59
60
    lora_request: LoRARequest,
    result_queue: Queue,
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
    dp_size: int,
    local_dp_rank: int,
    dp_master_port: int,
    dp_master_ip: str = "127.0.0.1",
) -> None:
    """
    Worker process for vLLM multiprocessing.
    Initializes a vLLM engine, processes requests, and puts results or errors
    onto the result_queue.
    """

    if not requests:
        result_queue.put((local_dp_rank, []))
        return None

    os.environ["VLLM_DP_RANK"] = os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = str(dp_master_ip)
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)

    llm = None
    try:
        llm = LLM(**model_args)
        res = llm.generate(
85
            [TokensPrompt(prompt_token_ids=request) for request in requests],
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
            sampling_params=sampling_params,
            lora_request=lora_request,
        )
        # Give engines time to pause their processing loops before exiting."
        sleep(1)
        result_queue.put((local_dp_rank, res))

    except Exception as e:
        error_message = f"Worker {local_dp_rank} failed during generation: {type(e).__name__}: {str(e)}"
        eval_logger.error(error_message, exc_info=True)
        result_queue.put((local_dp_rank, {"error": error_message}))

    finally:
        if llm is not None:
            try:
                del llm
                gc.collect()
            except Exception as e_cleanup:
                eval_logger.warning(
                    f"Worker {local_dp_rank} encountered an error during LLM cleanup: {type(e_cleanup).__name__}: {str(e_cleanup)}",
                    exc_info=True,
                )

    return None


baberabb's avatar
baberabb committed
112
@register_model("vllm")
113
class VLLM(TemplateLM):
baberabb's avatar
baberabb committed
114
115
116
117
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: str,
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: str | None = None,
        trust_remote_code: bool | None = False,
        tokenizer: str | None = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: str | None = None,
        add_bos_token: bool | None = False,
        prefix_token_id: int | None = None,
        tensor_parallel_size: int = 1,
        quantization: str | None = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: str | int = 1,
        max_batch_size: int | None = None,
        max_length: int | None = None,
        max_model_len: int | None = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        data_parallel_size: int = 1,
        lora_local_path: str | None = None,
        # VLLM: enable thinking tags in the prompt.
        enable_thinking: bool = True,
        chat_template_args: dict | None = None,
        # End marker for thinking tags - splits to get response after this token (if provided).
        think_end_token: str | None = None,
        max_lora_rank: int = 16,
        **kwargs,
    ):
        """
        Configure a vLLM-backed language model for evaluation.

        Args:
            pretrained: model name/path; used for the vLLM engine, the HF config
                and (unless ``tokenizer`` is given) the tokenizer.
            max_length / max_model_len: mutually exclusive aliases for the
                engine's ``max_model_len``.
            batch_size: an int, or any string containing ``"auto"``. Forced to
                ``"auto"`` when ``data_parallel_size > 1``.
            max_batch_size: default for the engine's ``max_num_seqs`` (a
                ``max_num_seqs`` kwarg takes precedence).
            data_parallel_size: if > 1, no engine is created here; engines are
                started per shard in ``_model_generate``.
            add_bos_token: forwarded to the tokenizer when not None.
            enable_thinking: passed to the chat template; overridden by
                ``chat_template_args["enable_thinking"]`` when present.
            chat_template_args: extra kwargs for ``apply_chat_template``. The
                dict is copied, so the caller's dict is never modified.
            think_end_token: forwarded to ``postprocess_generated_text``.
            lora_local_path: path of a LoRA adapter; enables LoRA in the engine.
            **kwargs: merged into the engine arguments, overriding the values
                above; a ``device`` key is discarded.
        """
        super().__init__()

        if not find_spec("vllm"):
            raise ModuleNotFoundError(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            )

        assert max_length is None or max_model_len is None, (
            "Either max_length or max_model_len may be provided, but not both"
        )
        kwargs.pop("device", None)
        self.think_end_token = think_end_token
        self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0"
        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
        self.model_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
            "max_num_seqs": kwargs.get("max_num_seqs", max_batch_size),
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
            "enable_lora": bool(lora_local_path),
            "max_lora_rank": int(max_lora_rank),
        }
        self.model_args.update(kwargs)
        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
            else int(batch_size)
        )
        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
            # The V0 engine path in _model_generate runs shards as Ray tasks.
            self.model_args["distributed_executor_backend"] = (
                "ray"
                if not self.V1
                else self.model_args.get("distributed_executor_backend", None)
            )
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")

        self.add_bos_token = add_bos_token

        from transformers import AutoConfig

        self._config = AutoConfig.from_pretrained(
            pretrained, trust_remote_code=trust_remote_code, revision=revision
        )
        self.tokenizer = get_tokenizer(
            tokenizer if tokenizer else pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            revision=tokenizer_revision,
            **(
                {"add_bos_token": self.add_bos_token}
                if self.add_bos_token is not None
                else {}
            ),
        )
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
        # Copy: popping "enable_thinking" must not mutate the caller's dict.
        self.chat_template_args = dict(chat_template_args or {})
        self.enable_thinking = self.chat_template_args.pop(
            "enable_thinking", enable_thinking
        )

        vllm_version = parse_version(version("vllm"))
        if vllm_version >= parse_version("0.8.3"):
            kwargs_resolve_hf_chat_template = {
                "tokenizer": self.tokenizer,
                "chat_template": None,
                "tools": None,
            }
            if vllm_version >= parse_version("0.9.0"):
                if self.data_parallel_size <= 1:
                    kwargs_resolve_hf_chat_template["model_config"] = (
                        self.model.llm_engine.model_config
                    )
                else:
                    # No engine exists in this process under data parallelism,
                    # so build a model config from the engine arguments instead.
                    from vllm.engine.arg_utils import EngineArgs

                    engine_args = EngineArgs(**self.model_args)
                    kwargs_resolve_hf_chat_template["model_config"] = (
                        engine_args.create_model_config()
                    )
            else:
                kwargs_resolve_hf_chat_template["trust_remote_code"] = trust_remote_code

            self.hf_chat_template = resolve_hf_chat_template(
                **kwargs_resolve_hf_chat_template
            )
        else:
            self.hf_chat_template = None

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

        self._max_gen_toks = max_gen_toks

        # Same truthiness test as "enable_lora" above, so an empty path never
        # yields a LoRA request for an engine without LoRA enabled.
        if lora_local_path:
            assert vllm_version > parse_version("0.3.0"), (
                "lora adapters only compatible with vllm > v0.3.0."
            )
            self.lora_request = LoRARequest("finetuned", 1, lora_local_path)
        else:
            self.lora_request = None

baberabb's avatar
baberabb committed
270
    @property
    def eot_token_id(self) -> int | None:
        """End-of-text token id, taken from the tokenizer's ``eos_token_id``.

        Named "EOT" on purpose: end of *text* describes its role in evaluation
        more accurately than end of *sentence*.
        """
        tok = self.tokenizer
        return tok.eos_token_id

275
276
277
278
279
280
281
282
283
    @property
    def prefix_token_id(self):
        """Token id used as the prefix (context) for loglikelihood scoring.

        Resolution order: the user-supplied ``prefix_token_id``, then the
        tokenizer's BOS id, then its EOS id.
        """
        custom = self.custom_prefix_token_id
        if custom is None:
            bos = self.tokenizer.bos_token_id
            return self.tokenizer.eos_token_id if bos is None else bos
        return custom

baberabb's avatar
baberabb committed
284
285
    @property
    def max_length(self):
        """Maximum sequence length (prompt + generation) in tokens.

        Resolution order:

        1. ``max_model_len`` / ``max_length`` given to the constructor;
        2. with a single engine, the loaded vLLM engine's ``max_model_len``;
        3. under data parallelism (no engine in this process): the first of
           ``n_positions`` / ``max_position_embeddings`` / ``n_ctx`` found on
           the HF config, then the tokenizer's ``model_max_length`` (unless it
           is the "unset" sentinel), then ``_DEFAULT_MAX_LENGTH``.
        """
        # ``self._max_length`` is the plain instance attribute set in __init__.
        # It must not be a property: __init__ assigns to it, and a property
        # reading ``self._max_length`` would recurse forever.
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        if self.data_parallel_size <= 1:
            return self.model.llm_engine.model_config.max_model_len
        for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
            if hasattr(self._config, attr):
                return getattr(self._config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            # int(1e30): the value transformers uses when a tokenizer has no
            # configured length limit.
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH
baberabb's avatar
baberabb committed
304
305
306
307
308

    @property
    def max_gen_toks(self):
        """Generation budget (tokens) used when a request sets no ``max_gen_toks``."""
        budget = self._max_gen_toks
        return budget

Baber Abbasi's avatar
Baber Abbasi committed
309
    def apply_chat_template(
        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """
        Render ``chat_history`` into a prompt string using the tokenizer's chat template.

        When ``add_generation_prompt`` is False the final message is continued
        instead. If the template raises ``jinja2.exceptions.TemplateError``,
        rendering is retried once with every ``"system"`` message removed.
        """
        template_kwargs = dict(
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
            chat_template=self.hf_chat_template,
            enable_thinking=self.enable_thinking,
            **self.chat_template_args,
        )
        try:
            rendered = self.tokenizer.apply_chat_template(
                chat_history, **template_kwargs
            )
        except jinja2.exceptions.TemplateError:
            eval_logger.warning(
                "Failed to apply chat template. removing the system role in chat history."
            )
            without_system = [msg for msg in chat_history if msg["role"] != "system"]
            rendered = self.tokenizer.apply_chat_template(
                without_system, **template_kwargs
            )
        return rendered

341
342
343
344
    @property
    def tokenizer_name(self) -> str:
        """Tokenizer ``name_or_path`` with each ``/`` turned into ``__`` (path-safe)."""
        name = self.tokenizer.name_or_path
        return "__".join(name.split("/"))

baberabb's avatar
baberabb committed
345
346
    def tok_encode(
        self,
        string: str | list[str],
        left_truncate_len: int | None = None,
        add_special_tokens: bool | None = None,
        truncation: bool = False,
    ) -> list[int] | list[list[int]]:
        """
        Tokenize one string or a batch of strings.

        Args:
            string: a single prompt, or a sequence of prompts (batch).
            left_truncate_len: if truthy, keep only the last ``left_truncate_len``
                tokens of each encoding.
            add_special_tokens: when this or ``self.add_bos_token`` is not None,
                ``add_special_tokens or self.add_bos_token`` is forwarded to the
                tokenizer. It is forced to False when the (first) prompt already
                starts with the BOS token.
            truncation: forwarded to the tokenizer.

        Returns:
            A list of token ids for a ``str`` input, else one list per prompt.
        """
        add_special_kwargs = (
            {"add_special_tokens": add_special_tokens or self.add_bos_token}
            if (add_special_tokens is not None or self.add_bos_token is not None)
            else {}
        )
        # handle chat template: a rendered prompt may already begin with BOS.
        # A ``str`` is itself a Sequence, so test for ``str`` first; otherwise only
        # its first *character* would be checked (and "" would raise IndexError).
        if isinstance(string, str):
            first_prompt = string
        else:
            first_prompt = string[0] if string else None
        if first_prompt and bos_already_added(first_prompt, self.tokenizer.bos_token):
            add_special_kwargs = {"add_special_tokens": False}

        encoding: list[list[int]] | list[int] = self.tokenizer(
            string,
            truncation=truncation,
            return_attention_mask=False,
            **add_special_kwargs,
        ).input_ids

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            if not isinstance(string, str):
                encoding = [enc[-left_truncate_len:] for enc in encoding]
            else:
                encoding = encoding[-left_truncate_len:]

        return encoding

    def _model_generate(
        self,
        requests: list[list[int]],
        generate: bool = False,
        sampling_params: list[SamplingParams] | SamplingParams | None = None,
    ):
        """
        Run vLLM on pre-tokenized prompts.

        Args:
            requests: token-id prompts.
            generate: when False (or ``sampling_params`` is None), a scoring
                configuration is used: ``temperature=0``, ``max_tokens=1``,
                ``prompt_logprobs=1``, ``detokenize=False``.
            sampling_params: a single ``SamplingParams`` applied to every prompt,
                or one per prompt.

        Returns:
            The vLLM request outputs. On the data-parallel paths the shards are
            merged with ``undistribute`` (presumably the inverse of the
            interleaving ``distribute``, restoring the order of ``requests``).

        Execution paths:
            * ``data_parallel_size > 1`` with the V0 engine: one Ray task per shard;
            * ``data_parallel_size > 1`` with V1: one ``multiprocessing.Process``
              per shard running ``_vllm_mp_worker``;
            * otherwise: the engine created in ``__init__``.
        """
        if not generate or sampling_params is None:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(requests)

        if self.data_parallel_size > 1 and not self.V1:
            # vLLM hangs if resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            @ray.remote
            def run_inference_one_model(
                model_args: dict,
                sampling_params: list[SamplingParams],
                requests: list[list[int]],
                lora_request: LoRARequest,
            ):
                llm = LLM(**model_args)
                return llm.generate(
                    [TokensPrompt(prompt_token_ids=request) for request in requests],
                    sampling_params=sampling_params,
                    lora_request=lora_request,
                )

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            sampling_params = [
                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
            ]
            inputs = (
                (self.model_args, sp, req, self.lora_request)
                for req, sp in zip(requests, sampling_params)
            )
            try:
                object_refs = [run_inference_one_model.remote(*x) for x in inputs]
                results = ray.get(object_refs)
            finally:
                # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
                # In ``finally`` so a failing task does not leave Ray running.
                ray.shutdown()
            # flatten results
            return undistribute(results)

        elif self.data_parallel_size > 1:
            # based on https://github.com/vllm-project/vllm/blob/a04720bc36401d831cb048c3917b9e58173d9c1d/examples/offline_inference/data_parallel.py
            dp_size = self.data_parallel_size
            dp_master_ip = os.environ.get("VLLM_DP_MASTER_IP", "127.0.0.1")
            dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()

            requests = (list(x) for x in distribute(self.data_parallel_size, requests))
            sampling_params = (
                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
            )
            procs, resq = [], Queue()
            # We use Process as it is non-daemonic
            try:
                for rank, (req, sp) in enumerate(zip(requests, sampling_params)):
                    proc = Process(
                        target=_vllm_mp_worker,
                        args=(
                            self.model_args.copy(),
                            sp,
                            req,
                            self.lora_request,
                            resq,
                            dp_size,
                            rank,
                            dp_master_port,
                            dp_master_ip,
                        ),
                    )
                    proc.start()
                    procs.append(proc)

                # Collect results before joining: a child blocks on exit until
                # its queued data has been read.
                rank_res = {}
                while len(rank_res) < len(procs):
                    try:
                        rank, result = resq.get(timeout=30)
                        if isinstance(result, dict) and "error" in result:
                            raise RuntimeError(result["error"])
                        rank_res[rank] = result
                    except Empty:
                        # Nothing arrived in time; fail fast if a worker died
                        # without reporting, otherwise keep waiting.
                        dead_procs = [
                            idx
                            for idx, p in enumerate(procs)
                            if not p.is_alive() and idx not in rank_res
                        ]
                        if dead_procs:
                            raise RuntimeError(
                                f"Worker processes {dead_procs} died unexpectedly"
                            )
                        continue

                results = [rank_res[i] for i in range(len(procs))]
                return undistribute(results)

            # cleanup
            finally:
                try:
                    resq.close()
                    resq.join_thread()
                except Exception:
                    eval_logger.debug(
                        "Failed to close vllm DP results queue", exc_info=True
                    )
                # Escalate join -> terminate -> kill so no worker outlives this call.
                for proc in procs:
                    proc.join(timeout=10)
                    if proc.is_alive():
                        proc.terminate()
                        proc.join(timeout=5)
                        if proc.is_alive():
                            proc.kill()

        else:
            outputs = self.model.generate(
                [TokensPrompt(prompt_token_ids=request) for request in requests],
                sampling_params=sampling_params,
                use_tqdm=self.batch_size == "auto",
                lora_request=self.lora_request,
            )
            return outputs
baberabb's avatar
baberabb committed
505

506
    def loglikelihood_rolling(
        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[float]:
        """
        Compute the loglikelihood of each request's full text.

        Each text is tokenized and cut into disjoint rolling windows whose size
        is derived from ``self.max_length`` (via ``get_rolling_token_windows`` and
        ``make_disjoint_window``). The windows of all requests are scored together
        in batches by ``_loglikelihood_tokens``, and the per-window loglikelihoods
        are summed back per request.

        Returns:
            One loglikelihood per request, in request order.
        """
        # "auto" scores everything in one batch. ``max(..., 1)`` keeps range()
        # valid with no requests; previously that case fell through to
        # int("auto") and raised ValueError.
        if self.batch_size == "auto":
            batch_size = max(len(requests), 1)
        else:
            batch_size = int(self.batch_size)

        all_windows = []  # (request_idx, window) pairs, in request order
        request_window_counts = []  # number of windows contributed by each request

        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                map(
                    make_disjoint_window,
                    get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        # max_seq_len - (1 for context)
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [(None,) + x for x in rolling_token_windows]

            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))

        all_nlls = []  # (request_idx, (loglikelihood, is_greedy)) in window order
        for start in range(0, len(all_windows), batch_size):
            batch_indices, batch_windows = zip(*all_windows[start : start + batch_size])
            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                disable_tqdm=False,
            )
            all_nlls.extend(zip(batch_indices, batch_nlls))

        # Reconstruct per-request loglikelihoods: windows are contiguous per request.
        loglikelihoods = []
        offset = 0
        for request, window_count in zip(requests, request_window_counts):
            request_nlls = all_nlls[offset : offset + window_count]
            offset += window_count
            # Sum up the loglikelihoods for this request (discarding is_greedy)
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (request.args[0],), request_total
            )

        return loglikelihoods

575
    def generate_until(
        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[str]:
        """
        Generate a continuation for each request.

        Each ``Instance.args`` is ``(context, gen_kwargs)``, and ``gen_kwargs``
        must be a dict:

        * its optional ``until`` (stop sequences; EOS is added via
          ``handle_stop_sequences``) and ``max_gen_toks`` keys are consumed here;
        * the remaining keys go through ``self.modify_gen_kwargs`` into
          ``SamplingParams``.

        Contexts that do not leave room for ``max_gen_toks`` new tokens within
        ``self.max_length`` are left-truncated.

        Returns:
            The generated strings, in the original request order.

        Raises:
            ValueError: if a request's ``gen_kwargs`` is not a dict.
        """
        if not requests:
            return []

        res = []

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding = self.tok_encode(context)
        reqs = [
            ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), _requests[0][0]

        re_ords = Collator(
            reqs,
            _collate_gen,
            group_by=None,
        )
        chunks = re_ords.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(reqs),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        eos = self.tokenizer.decode(self.eot_token_id)
        for chunk in chunks:
            context_and_encoding, chunk_gen_kwargs = zip(*chunk)
            chunk_contexts, chunk_encodings = zip(*context_and_encoding)

            # Chunks are not grouped by gen_kwargs (group_by=None), so stop
            # sequences are tracked per request rather than per chunk.
            context_encoding_truncated = []
            sampling_params = []
            chunk_untils = []
            for x, gen_kwargs in zip(chunk_encodings, chunk_gen_kwargs):
                # unpack our keyword arguments.
                if not isinstance(gen_kwargs, dict):
                    raise ValueError(
                        f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                    )
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                # add EOS token to stop sequences
                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
                max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

                # set the max length in tokens of inputs ("context_enc")
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
                if len(x) > max_ctx_len:
                    eval_logger.warning(
                        f"Context length {len(x)} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
                    )
                    # Keep at least one token if max_gen_toks alone fills the window.
                    x = x[-max(max_ctx_len, 1) :]
                context_encoding_truncated.append(x)
                chunk_untils.append(until)

                # create sampling params
                kwargs = self.modify_gen_kwargs(kwargs)
                sampling_params.append(
                    SamplingParams(max_tokens=max_gen_toks, stop=until, **kwargs)
                )

            # perform batched generation
            cont = self._model_generate(
                requests=context_encoding_truncated,
                generate=True,
                sampling_params=sampling_params,
            )

            # cache generations
            for output, ctx, gen_kwargs, until in zip(
                cont, chunk_contexts, chunk_gen_kwargs, chunk_untils
            ):
                generated_text: str = output.outputs[0].text
                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                generated_text = postprocess_generated_text(
                    generated_text, until, self.think_end_token
                )
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (ctx, gen_kwargs), generated_text
                )
                pbar.update(1)

        pbar.close()
        # reorder all group of results back to original unsorted form
        return re_ords.get_original(res)
baberabb's avatar
baberabb committed
667
668

    def _loglikelihood_tokens(
        self,
        requests: list[tuple[tuple[str, str], list[int], list[int]]],
        disable_tqdm: bool = False,
    ) -> list[tuple[float, bool]]:
        """Score pre-tokenized (context, continuation) pairs via prompt logprobs.

        :param requests: list
            ``(cache_key, context_enc, continuation_enc)`` tuples. ``cache_key``
            may be ``None``; such requests are not written to the cache here.
        :param disable_tqdm: bool
            Suppress the progress bar. It is also suppressed on non-zero ranks,
            matching ``generate_until``.
        :return:
            One ``(continuation_logprob, is_greedy)`` tuple per request, restored
            to the original request order.
        """
        res = []

        def _collate(x):
            # Longest sequences first (negative length); the token tuple breaks
            # ties deterministically.
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # Reorder requests by length and batch
        re_ord = Collator(requests, sort_fn=_collate)
        chunks = re_ord.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inputs = []
            ctxlens = []
            for _, context_enc, continuation_enc in chunk:
                full_enc = context_enc + continuation_enc
                # Number of leading tokens that do not fit in the model window.
                overflow = max(0, len(full_enc) - self.max_length)
                if overflow:
                    eval_logger.warning(
                        f"Context length {len(full_enc)} exceeds max length ({self.max_length}). Truncating context."
                    )
                # Left-truncate so the (scored) continuation is kept; the context
                # length shrinks by exactly the number of dropped tokens.
                inputs.append(full_enc[overflow:])
                ctxlens.append(len(context_enc) - overflow)

            outputs = self._model_generate(requests=inputs, generate=False)

            for output, ctxlen, (cache_key, _, _), inp in zip(
                outputs, ctxlens, chunk, inputs
            ):
                answer = self._parse_logprobs(
                    tokens=inp,
                    outputs=output,
                    ctxlen=ctxlen,
                )

                res.append(answer)

                if cache_key is not None:
                    # special case: loglikelihood_rolling produces a number of loglikelihood requests
                    # all with cache key None. instead do add_partial on the per-example level
                    # in the loglikelihood_rolling() function for those.
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)
        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Input tokens (potentially left-truncated)
        :param outputs: RequestOutput
            Contains prompt_logprobs
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

        def coerce_logprob_to_num(logprob):
            # vLLM changed the return type of logprobs from float
            # to a Logprob object storing the float value + extra data
            # (https://github.com/vllm-project/vllm/pull/3065).
            # If we are dealing with vllm's Logprob object, return
            # the logprob value stored as an attribute. Otherwise,
            # return the object itself (which should be a float
            # for older versions of vLLM).
            return getattr(logprob, "logprob", logprob)

        # The first entry of prompt_logprobs is None because the model has no
        # previous tokens to condition on. Only positions from ctxlen onward
        # (the continuation) are scored, so only those dicts are converted.
        continuation_logprobs_dicts = [
            {
                token: coerce_logprob_to_num(logprob)
                for token, logprob in logprob_dict.items()
            }
            if logprob_dict is not None
            else None
            for logprob_dict in outputs.prompt_logprobs[ctxlen:]
        ]
        continuation_tokens = tokens[ctxlen:]

        # Calculate continuation_logprobs
        # assume ctxlen always >= 1
        continuation_logprobs = sum(
            logprob_dict.get(token)
            for token, logprob_dict in zip(
                continuation_tokens, continuation_logprobs_dicts
            )
        )

        # Greedy iff, at every scored position with a distribution, the
        # highest-logprob token is the actual continuation token.
        is_greedy = all(
            max(logprob_dict, key=logprob_dict.get) == token
            for token, logprob_dict in zip(
                continuation_tokens, continuation_logprobs_dicts
            )
            if logprob_dict
        )

        return continuation_logprobs, is_greedy
792
793
794
795

    @staticmethod
    def modify_gen_kwargs(kwargs: dict) -> dict:
        """Adapt generation kwargs for vLLM ``SamplingParams``.

        The dict is modified in place and also returned.

        - ``do_sample`` is popped so it is not forwarded to ``SamplingParams``.
        - ``temperature`` defaults to 0.0 when not provided.
        - ``skip_special_tokens`` and ``spaces_between_special_tokens`` default
          to False (HF defaults).

        :param kwargs: dict
            Generation kwargs; mutated in place.
        :return: dict
            The same dict object, updated.
        """
        do_sample = kwargs.pop("do_sample", None)
        # sampling_params
        # Check for a missing temperature *before* filling in the default;
        # assigning the default first made the do_sample=False log unreachable.
        if "temperature" not in kwargs:
            if do_sample is False:
                eval_logger.debug(
                    "Got `do_sample=False` and no temperature value, setting VLLM temperature to 0.0 ..."
                )
            kwargs["temperature"] = 0.0
        # hf defaults
        kwargs.setdefault("skip_special_tokens", False)
        kwargs.setdefault("spaces_between_special_tokens", False)
        return kwargs