huggingface.py 64.9 KB
Newer Older
1
2
from __future__ import annotations

3
import copy
Lintang Sutawika's avatar
Lintang Sutawika committed
4
import logging
5
import os
Baber Abbasi's avatar
Baber Abbasi committed
6
from collections.abc import Iterator, Sequence
Jeevan's avatar
Jeevan committed
7
from datetime import timedelta
8
from pathlib import Path
Baber Abbasi's avatar
Baber Abbasi committed
9
from typing import TYPE_CHECKING, Any, Literal
10

11
import jinja2
12
import torch
13
import torch.nn.functional as F
14
import transformers
Jeevan's avatar
Jeevan committed
15
16
17
18
19
from accelerate import (
    Accelerator,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
Nathan Habib's avatar
Nathan Habib committed
20
from accelerate.utils import get_max_memory
21
from huggingface_hub import HfApi
22
from packaging import version
Baber Abbasi's avatar
Baber Abbasi committed
23
from packaging.version import parse as vparse
24
from tqdm import tqdm
25
26
27
28
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
29
30

from lm_eval import utils
31
from lm_eval.api.model import TemplateLM
32
from lm_eval.api.registry import register_model
33
34
35
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
36
    configure_pad_token,
37
    get_dtype,
38
    handle_stop_sequences,
39
    pad_and_concat,
40
    postprocess_generated_text,
41
42
    stop_sequences_criteria,
)
43

44

45
if TYPE_CHECKING:
Baber Abbasi's avatar
Baber Abbasi committed
46
47
48
    from transformers.quantizers.auto import AutoQuantizationConfig

    from lm_eval.api.instance import Instance
49

Lintang Sutawika's avatar
Lintang Sutawika committed
50
eval_logger = logging.getLogger(__name__)
Baber Abbasi's avatar
Baber Abbasi committed
51
TOKENIZER_INFINITY = 1000000000000000019884624838656
52

lintangsutawika's avatar
lintangsutawika committed
53

54
@register_model("hf-auto", "hf", "huggingface")
55
class HFLM(TemplateLM):
Baber Abbasi's avatar
Baber Abbasi committed
56
    """An abstracted Huggingface model class. Enables usage with both models of
57
58
59
60
61
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

62
    AUTO_MODEL_CLASS = None
63
    _DEFAULT_MAX_LENGTH = 2048
haileyschoelkopf's avatar
haileyschoelkopf committed
64

65
66
    def __init__(
        self,
        pretrained: str | transformers.PreTrainedModel,
        backend: Literal["default", "causal", "seq2seq"] = "default",
        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        revision: str | None = "main",
        subfolder: str = "",
        tokenizer: str
        | transformers.PreTrainedTokenizer
        | transformers.PreTrainedTokenizerFast
        | None = None,
        truncation: bool | None = False,
        logits_cache: bool = True,
        max_length: int | None = None,
        device: str | None = "cuda",
        dtype: str | torch.dtype | None = "auto",
        softmax_dtype: str | torch.dtype | None = None,
        mixed_precision_dtype: str | torch.dtype | None = None,
        batch_size: int | str | None = 1,
        max_batch_size: int | None = 64,
        trust_remote_code: bool | None = False,
        use_fast_tokenizer: bool | None = True,
        add_bos_token: bool | None = None,
        prefix_token_id: int | None = None,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: bool | None = False,
        max_memory_per_gpu: int | str | None = None,
        max_cpu_memory: int | str | None = None,
        offload_folder: str | os.PathLike | None = "./offload",
        # PEFT, delta weights and quantization options
        peft: str | None = None,
        delta: str | None = None,
        autogptq: bool | str | None = False,
        gptqmodel: bool | None = False,
        gguf_file: str | None = None,
        # end token for thinking, either the string or int token id.
        # splits to get response after this token (if provided).
        think_end_token: str | int | None = None,
        enable_thinking: bool | None = None,
        chat_template_args: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Build an HFLM around a Hugging Face causal or seq2seq model.

        Args:
            pretrained: Model name/path to load, or an already-initialized
                ``transformers.PreTrainedModel``. When a model object is passed,
                device placement and distributed setup are skipped and many of
                the loading options below are ignored.
            backend: ``"causal"`` or ``"seq2seq"`` to force a backend;
                ``"default"`` infers it from the model config.
            revision, subfolder, trust_remote_code, gguf_file: Forwarded to the
                config, tokenizer and model loaders.
            tokenizer: Tokenizer name/path or instance, passed to
                ``_create_tokenizer``.
            truncation, logits_cache: Stored on the instance as-is.
            max_length: Explicit maximum sequence length (see ``max_length``).
            device: Requested device for single-process, non-parallel runs; it is
                overridden under ``accelerate launch`` or ``parallelize=True``.
            dtype: Model dtype forwarded to ``_create_model``.
            softmax_dtype, mixed_precision_dtype: Optional dtypes resolved with
                ``get_dtype`` and stored on the instance.
            batch_size: An int, or ``"auto"`` / ``"auto:N"`` (``N`` becomes
                ``self.batch_schedule``).
            max_batch_size: Stored as ``self.max_batch_size``.
            use_fast_tokenizer: Forwarded to ``_create_tokenizer``.
            add_bos_token: Whether to add a BOS token. When ``None`` it is derived
                from the tokenizer's own ``add_bos_token`` attribute (``False`` if
                absent).
            prefix_token_id: Overrides the loglikelihood prefix token
                (see ``prefix_token_id``).
            parallelize, max_memory_per_gpu, max_cpu_memory, offload_folder:
                Naive model-parallel options (accelerate ``device_map``).
            peft, delta, autogptq, gptqmodel: PEFT / delta-weight / GPTQ options
                forwarded to ``_create_model``.
            think_end_token: End-of-thinking marker as a string or token id; a
                digit-only string is converted to ``int``.
            enable_thinking: When not ``None``, stored as the ``enable_thinking``
                entry of ``self.chat_template_args`` (overriding any entry of the
                same name in ``chat_template_args``).
            chat_template_args: Extra chat-template arguments; a copy is stored
                on ``self.chat_template_args``.
            **kwargs: Forwarded to ``_create_model``.
        """
        super().__init__()
        # optionally: take in an already-initialized transformers.PreTrainedModel
        if not isinstance(pretrained, str):
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
            )
            assert not parallelize, (
                "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            )
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
            gpus = 0

        else:
            assert isinstance(device, str)
            assert isinstance(pretrained, str)
            assert isinstance(batch_size, (int, str))

            accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
            accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
            if accelerator.num_processes > 1:
                self.accelerator = accelerator

            # Detect device count based on accelerator device type
            device_type = accelerator.device.type
            if "cuda" in device_type:
                gpus = torch.cuda.device_count()
            elif "npu" in device_type:
                gpus = torch.npu.device_count()
            elif "xpu" in device_type:
                gpus = torch.xpu.device_count()
            else:
                # Fallback to CUDA count for compatibility
                gpus = torch.cuda.device_count()

            # using one process with no model parallelism
            if not (parallelize or accelerator.num_processes > 1):
                # use user-passed device
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(gpus)]
                    + ["mps", "mps:0"]
                    + [f"npu:{i}" for i in range(gpus)]
                    + [f"xpu:{i}" for i in range(gpus)]
                )
                if device and device in device_list:
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and version.parse(
                        torch.__version__
                    ) < version.parse("2.1"):
                        raise RuntimeError(
                            f"mps requires torch >= 2.1. You have {torch.__version__}"
                        )
                else:
                    eval_logger.info("Device not specified")
                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                    self._device = (
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    )
            else:  # Parallelism managed by accelerate
                if device != "cuda":
                    eval_logger.info(
                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                    )
                # TODO: include in warning that `load_in_8bit` etc. affect this too
                self._device = (
                    self.accelerator.device
                    if hasattr(self, "accelerator")
                    else torch.device(device)
                )

            revision = str(revision)  # cast to string if not already one

            self._get_config(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
                gguf_file=gguf_file,
                subfolder=subfolder,
            )

        # determine which of 'causal' and 'seq2seq' backends to use for HF models
        self._get_backend(
            config=self.config, backend=backend, trust_remote_code=trust_remote_code
        )

        # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT
        self._create_tokenizer(
            pretrained,
            tokenizer,
            revision=revision,
            subfolder=subfolder,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
            gguf_file=gguf_file,
            add_bos_token=add_bos_token,
        )

        # A quantization config stored on the model config as a plain dict is
        # turned into a proper quantization-config object before model loading.
        if (
            quantization_config := getattr(self.config, "quantization_config", None)
        ) is not None and isinstance(quantization_config, dict):
            from transformers.quantizers import AutoQuantizationConfig

            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

        # if we passed `pretrained` as a string, initialize our model now
        if isinstance(pretrained, str):
            self._create_model(
                pretrained=pretrained,
                revision=revision,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                parallelize=parallelize,
                gpus=gpus,
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                peft=peft,
                delta=delta,
                autogptq=autogptq,
                gptqmodel=gptqmodel,
                gguf_file=gguf_file,
                quantization_config=quantization_config,
                subfolder=subfolder,
                **kwargs,
            )

        # access self._model through self.model property outside this method
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model.tie_weights()

        # A digit-only string is interpreted as a token id.
        self.think_end_token = (
            int(think_end_token)
            if (isinstance(think_end_token, str) and think_end_token.isdigit())
            else think_end_token
        )
        self.truncation = truncation
        self.logits_cache = logits_cache
        self.vocab_size = self.tokenizer.vocab_size
        # select (or create) a pad token to use
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)

        # Keep the caller's chat-template arguments (copied so later mutation of
        # self.chat_template_args cannot leak back into the caller's dict) and
        # layer `enable_thinking` on top only when it was explicitly given.
        self.chat_template_args = dict(chat_template_args or {})
        if enable_thinking is not None:
            self.chat_template_args["enable_thinking"] = enable_thinking

        self.add_bos_token = add_bos_token
        if self.add_bos_token is None:
            if getattr(self.tokenizer, "add_bos_token", False):
                self.add_bos_token = True
                eval_logger.info(
                    f"Tokenizer has 'add_bos_token' attribute set -- using BOS token based on tokenizer configuration for model type '{self.config.model_type}'. To control explicitly, set `add_bos_token=True|False`"
                )
            else:
                self.add_bos_token = False

        self._max_length = max_length
        self.pretrained = pretrained
        self.delta = delta
        self.peft = peft
        self.revision = revision
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size
        self.softmax_dtype = (
            get_dtype(softmax_dtype) if softmax_dtype is not None else None
        )
        self.mixed_precision_dtype = (
            get_dtype(mixed_precision_dtype)
            if mixed_precision_dtype is not None
            else None
        )

        # "auto" or "auto:N": keep the "auto" marker and use N as the schedule.
        if str(batch_size).startswith("auto"):
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            if (gpus >= 1 or str(self.device) == "mps") and not (
                parallelize or autogptq or hasattr(self, "accelerator")
            ):
                # TODO: can remove this whole snippet except in the mps case, perhaps?
                # place model onto device requested manually,
                # if not using HF Accelerate or device_map
                # or any other option that preloads model onto device
                try:
                    self.model.to(self.device)
                except ValueError:
                    eval_logger.debug(
                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                    )
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if accelerator.num_processes > 1:
                    if parallelize:
                        eval_logger.warning(
                            "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
                        )
                    elif gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                        if self.accelerator.is_local_main_process:
                            eval_logger.info(
                                f"Using {gpus} devices with data parallelism"
                            )

                    self._device = torch.device(f"{accelerator.device}")
                    self.accelerator = accelerator

                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
                else:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )
349

Nathan Habib's avatar
Nathan Habib committed
350
351
    def _get_accelerate_args(
        self,
        parallelize: bool | None = None,
        device_map: str | None = "auto",
        max_memory_per_gpu: int | str | None = None,
        max_cpu_memory: int | str | None = None,
        offload_folder: str | None = "./offload",
        gpus: int | None = None,
    ) -> dict:
        """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`.

        Args:
            parallelize: Whether to shard the model across local GPUs via an
                accelerate ``device_map``. ``None`` means decide automatically:
                enable it when more GPUs are visible than there are local
                processes. Forced to ``False`` when an accelerator is attached but
                ``WORLD_SIZE // LOCAL_WORLD_SIZE`` is 0.
            device_map: Device map to forward. When model parallelism is off and
                this is ``None``, the whole model is mapped onto the current device.
            max_memory_per_gpu: Memory budget applied to each GPU index in
                ``range(gpus)`` when parallelizing; when ``None`` the budget is
                estimated with ``accelerate.utils.get_max_memory``.
            max_cpu_memory: Optional CPU memory budget, added under the ``"cpu"``
                key of the returned ``max_memory`` when parallelizing.
            offload_folder: Returned as ``offload_folder`` when parallelizing.
            gpus: Number of GPUs; must be an int when ``max_memory_per_gpu`` is set.

        Returns:
            A dict with ``max_memory`` and ``device_map`` entries (plus
            ``offload_folder`` when parallelizing).
        """
        num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
        if (
            num_machines == 0
            and hasattr(self, "accelerator")
            and self.accelerator is not None
        ):
            eval_logger.info(
                "We are not in a distributed setting for accelerate. Setting model_parallel to False."
            )
            parallelize = False

        if parallelize is None:
            # If parallelism is unset by the user, we automatically assign model parallelism
            # if enough extra GPUs are available
            max_memory_all_gpus = get_max_memory()
            # We just want gpu, not cpu, max memory
            max_memory_all_gpus.pop("cpu", None)
            parallelize = bool(num_local_processes < len(max_memory_all_gpus))
            eval_logger.info(
                f"Setting model parallel to {parallelize} since "
                f"the number of local processes is {num_local_processes} "
                f"and the number of GPUs is {len(max_memory_all_gpus)}"
            )

        args = {}
        if parallelize:  # Model parallelism will be used
            if max_memory_per_gpu is not None:  # Using the provided memory requirements
                max_memory_per_gpu_map = {
                    device_idx: max_memory_per_gpu for device_idx in range(gpus)
                }
            else:  # Estimating the possible memory requirements
                max_memory_all_gpus = get_max_memory()
                max_memory_all_gpus.pop("cpu", None)
                if hasattr(self, "accelerator"):
                    # use only 1 / num_processes of the GPUs if we are running under accelerate launch
                    max_memory_per_gpu_map = {
                        k: v
                        for k, v in max_memory_all_gpus.items()
                        if k % num_local_processes
                        == (self.accelerator.process_index % num_local_processes)
                    }
                else:
                    max_memory_per_gpu_map = max_memory_all_gpus

            # The CPU budget must live in the same `max_memory` dict that is
            # returned, otherwise `max_cpu_memory` never reaches from_pretrained.
            # Copy so the GPU-only map logged below is left unchanged.
            max_memory = dict(max_memory_per_gpu_map)
            if max_cpu_memory is not None:
                max_memory["cpu"] = max_cpu_memory

            args["max_memory"] = max_memory
            args["device_map"] = "auto" if device_map is None else device_map
            eval_logger.info(
                f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}"
            )

            args["offload_folder"] = offload_folder
        elif (
            device_map is None
        ):  # No model parallelism, we use the default provided device for our model
            if hasattr(self, "accelerator"):
                device_map = {"": f"{self.accelerator.device}"}
            else:
                device_map = {"": str(self.device)}
            args["max_memory"] = None
            args["device_map"] = device_map
            eval_logger.info(
                f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
            )
        else:
            args["max_memory"] = None
            args["device_map"] = None
            eval_logger.info("Model parallel was set to False.")

        return args

436
437
438
439
440
    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

441
442
443
444
445
446
447
448
    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

449
    @property
Baber Abbasi's avatar
Baber Abbasi committed
450
    def eot_token_id(self) -> int:
451
452
453
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

454
    @property
Baber Abbasi's avatar
Baber Abbasi committed
455
    def prefix_token_id(self) -> int:
456
457
458
459
460
461
462
        # it is used as prefix for loglikelihood
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

463
    @property
Baber Abbasi's avatar
Baber Abbasi committed
464
    def max_length(self) -> int:
465
466
467
468
469
470
471
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
Baber Abbasi's avatar
Baber Abbasi committed
472
            if self.tokenizer.model_max_length == TOKENIZER_INFINITY:
473
474
475
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH
476

477
    @property
Ethan Smith's avatar
Ethan Smith committed
478
    def max_gen_toks(self) -> int:
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

KonradSzafer's avatar
KonradSzafer committed
497
498
499
500
    @property
    def tokenizer_name(self) -> str:
        return self.tokenizer.name_or_path.replace("/", "__")

501
502
    def _get_backend(
        self,
        config: transformers.PretrainedConfig | transformers.AutoConfig,
        backend: Literal["default", "causal", "seq2seq"] = "default",
        trust_remote_code: bool | None = False,
    ) -> None:
        """Helper method during initialization.

        Determines the backend ("causal" (decoder-only) or "seq2seq"
        (encoder-decoder)) model type to be used, stores it on ``self.backend``,
        and sets ``self.AUTO_MODEL_CLASS`` appropriately if not already set.

        Args:
            config: Model config; its ``model_type`` is looked up in the
                transformers seq2seq and causal-LM model-name mappings.
            backend: ``"default"`` to infer the backend from ``config``, or
                ``"causal"`` / ``"seq2seq"`` to force it.
            trust_remote_code: Only controls whether a warning is logged when the
                model type is in neither mapping.

        **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
        user must set `self.backend` to be either "causal" or "seq2seq" manually!**
        """
        assert backend in ["default", "causal", "seq2seq"]

        if backend != "default":
            # if we've settled on non-default backend, use that manually
            self.backend = backend
            eval_logger.info(
                f"Overrode HF model backend type, and using type '{self.backend}'"
            )
        else:
            # determine and use the default HF backend for this model, based on its config + metadata.
            model_type = getattr(config, "model_type", None)
            if model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                # first check if model type is listed under seq2seq models, since some
                # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                # these special cases should be treated as seq2seq models.
                self.backend = "seq2seq"
                eval_logger.debug(f"Using model type '{self.backend}'")
            elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                self.backend = "causal"
                eval_logger.debug(f"Using model type '{self.backend}'")
            else:
                if not trust_remote_code:
                    # Adjacent literals with explicit spaces: the previous
                    # backslash continuation embedded indentation in the message
                    # and glued "otherwise." to "Setting".
                    eval_logger.warning(
                        "HF model type is neither marked as CausalLM or Seq2SeqLM. "
                        "This is expected if your model requires `trust_remote_code=True` "
                        "but may be an error otherwise. "
                        "Setting backend to causal"
                    )
                # if model type is neither in HF transformers causal or seq2seq model registries
                # then we default to assuming AutoModelForCausalLM
                self.backend = "causal"
                eval_logger.info(
                    f"Model type cannot be determined. Using default model type '{self.backend}'"
                )

        if self.AUTO_MODEL_CLASS is None:
            if self.backend == "causal":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            elif self.backend == "seq2seq":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
560
561
562
563
564
565

    def _get_config(
        self,
        pretrained: str,
        revision: str = "main",
        trust_remote_code: bool = False,
        gguf_file: str | None = None,
        subfolder: str = "",
    ) -> None:
        """Load the Hugging Face config for ``pretrained`` and store it on ``self._config``.

        Nothing is returned; the loaded config is exposed via the ``config`` property.
        """
        load_options = {
            "revision": revision,
            "trust_remote_code": trust_remote_code,
            "gguf_file": gguf_file,
            "subfolder": subfolder,
        }
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained, **load_options
        )

    def _create_model(
        self,
        pretrained: str,
Baber Abbasi's avatar
Baber Abbasi committed
581
582
583
        revision: str | None = "main",
        dtype: str | torch.dtype | None = "auto",
        trust_remote_code: bool | None = False,
584
585
586
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # (accelerate naive PP (device_map) options)
Baber Abbasi's avatar
Baber Abbasi committed
587
588
589
590
591
        parallelize: bool | None = False,
        gpus: int | None = None,
        max_memory_per_gpu: int | str | None = None,
        max_cpu_memory: int | str | None = None,
        offload_folder: str | None = "./offload",
592
        # PEFT, delta weights and quantization options
Baber Abbasi's avatar
Baber Abbasi committed
593
594
595
596
597
598
        peft: str | None = None,
        delta: str | None = None,
        autogptq: bool | str | None = False,
        gptqmodel: bool | None = False,
        gguf_file: str | None = None,
        quantization_config: AutoQuantizationConfig | None = None,
599
        subfolder: str = "",
600
601
        **kwargs,
    ) -> None:
Baber Abbasi's avatar
Baber Abbasi committed
602
        """Initializes an HF or HF-compatible PreTrainedModel from scratch
603
604
605
606
607
608
609
610
611
612
        inside HFLM, using the kwargs passed into self.__init__().

        Also handles functionality such as AutoGPTQ usage and PEFT wrapping.

        For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
        (such as PyTorch models that are nearly, but not quite, fully mirroring
        HF's public interface relied on in this HFLM class)
        please consider subclassing HFLM and overriding this and other methods as needed.
        """

Baber Abbasi's avatar
Baber Abbasi committed
613
        model_kwargs = kwargs or {}
614

Nathan Habib's avatar
Nathan Habib committed
615
616
617
        model_kwargs.update(
            self._get_accelerate_args(
                parallelize=parallelize,
Baber Abbasi's avatar
Baber Abbasi committed
618
                device_map=kwargs.get("device_map"),
Nathan Habib's avatar
Nathan Habib committed
619
620
621
622
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                gpus=gpus,
623
            )
Nathan Habib's avatar
Nathan Habib committed
624
        )
625

626
        if not autogptq and not gptqmodel:
Baber Abbasi's avatar
Baber Abbasi committed
627
628
            if model_kwargs.get("load_in_4bit"):
                assert vparse(transformers.__version__) >= vparse("4.30.0"), (
Baber Abbasi's avatar
Baber Abbasi committed
629
630
                    "load_in_4bit requires transformers >= 4.30.0"
                )
Baber Abbasi's avatar
Baber Abbasi committed
631
632
                if compute_dtype := model_kwargs.get("bnb_4bit_compute_dtype"):
                    model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(compute_dtype)
Nathan Habib's avatar
Nathan Habib committed
633

634
635
636
            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
637
                torch_dtype=get_dtype(dtype),
638
                trust_remote_code=trust_remote_code,
639
                gguf_file=gguf_file,
640
                quantization_config=quantization_config,
641
                subfolder=subfolder,
642
643
644
                **model_kwargs,
            )
        else:
645
646
647
            if autogptq and gptqmodel:
                raise ValueError(
                    "Cannot use both 'autogptq' and 'gptqmodel' options at the same time."
648
649
                )

650
651
652
653
654
655
656
            if autogptq:
                try:
                    from auto_gptq import AutoGPTQForCausalLM
                except ModuleNotFoundError as exception:
                    raise type(exception)(
                        "Tried to load auto_gptq, but auto-gptq is not installed ",
                        "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
Baber Abbasi's avatar
Baber Abbasi committed
657
                    ) from exception
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675

                self._model = AutoGPTQForCausalLM.from_quantized(
                    pretrained,
                    trust_remote_code=trust_remote_code,
                    model_basename=None if autogptq is True else Path(autogptq).stem,
                    use_safetensors=True
                    if autogptq is True
                    else autogptq.endswith(".safetensors"),
                    **model_kwargs,
                )

            if gptqmodel:
                try:
                    from gptqmodel import GPTQModel
                except ModuleNotFoundError as exception:
                    raise type(exception)(
                        "Tried to load gptqmodel, but gptqmodel is not installed ",
                        "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
Baber Abbasi's avatar
Baber Abbasi committed
676
                    ) from exception
677
678
679
680

                self._model = GPTQModel.from_quantized(
                    pretrained, trust_remote_code=trust_remote_code, **model_kwargs
                )
681

682
683
684
685
686
        if peft and delta:
            raise ValueError(
                "Cannot use both 'peft' and 'delta' options at the same time."
            )

687
        if peft:
688
689
690
            from peft import PeftModel
            from peft import __version__ as PEFT_VERSION

Baber Abbasi's avatar
Baber Abbasi committed
691
692
693
694
            if model_kwargs.get("load_in_4bit") and vparse(PEFT_VERSION) < vparse(
                "0.4.0"
            ):
                raise AssertionError("load_in_4bit requires peft >= 0.4.0")
695
696

            # Compatible with Gemma3 (multimodal) and old models
Janna's avatar
Janna committed
697
698
699
            if hasattr(self._model.config, "text_config") and hasattr(
                self._model.config.text_config, "vocab_size"
            ):
700
701
702
                vocab_size = self._model.config.text_config.vocab_size
            else:
                vocab_size = self._model.config.vocab_size
Janna's avatar
Janna committed
703

704
            if vocab_size != len(self.tokenizer):
705
                # resize model for LoRAs with added tokens
706
                eval_logger.info(
707
                    f"Model config indicates vocab_size='{vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
708
                )
709
                self._model.resize_token_embeddings(len(self.tokenizer))
710
711
712
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
        elif delta:
            if autogptq:
                eval_logger.warning(
                    "Delta weights might trigger unexpected behavior when used with AutoGPTQ."
                )
            _model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
                delta,
                revision=revision,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )
            for name, param in self._model.state_dict().items():
                try:
                    param.data += _model_delta.state_dict()[name]
Baber Abbasi's avatar
Baber Abbasi committed
728
729
730
731
                except KeyError as e:
                    raise KeyError(
                        f"Delta model is missing weights for layer: {name}"
                    ) from e
732
733
734
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to add delta weights to layer {name}. Error: {e}"
Baber Abbasi's avatar
Baber Abbasi committed
735
                    ) from e
736
737

            del _model_delta
738
739
740

    def _create_tokenizer(
        self,
        pretrained: str | transformers.PreTrainedModel,
        tokenizer: str
        | transformers.PreTrainedTokenizer
        | transformers.PreTrainedTokenizerFast
        | None,
        revision: str | None = "main",
        trust_remote_code: bool | None = False,
        use_fast_tokenizer: bool | None = True,
        gguf_file: str | None = None,
        add_bos_token: bool | None = False,
        subfolder: str | None = "",
    ) -> None:
        """Set ``self.tokenizer`` during initialization.

        Resolution order:
        - ``tokenizer`` is a string: load it with
          ``transformers.AutoTokenizer.from_pretrained``;
        - ``tokenizer`` is a tokenizer instance: assert it is a (slow or fast)
          HF tokenizer and use it as-is;
        - ``tokenizer`` is falsy: load from ``pretrained`` when it is a string,
          otherwise from ``self.model.name_or_path``.
        """
        load_kwargs = {"revision": revision, "trust_remote_code": trust_remote_code}

        # A GGUF file carries its own tokenizer and does not take the HF
        # `use_fast` flag, so exactly one of the two keys is passed.
        use_gguf = not tokenizer and gguf_file is not None
        if use_gguf:
            load_kwargs["gguf_file"] = gguf_file
        else:
            load_kwargs["use_fast"] = use_fast_tokenizer

        if add_bos_token:
            load_kwargs["add_bos_token"] = True
        if subfolder:
            load_kwargs["subfolder"] = subfolder

        if not tokenizer:
            # No explicit tokenizer: derive the hub name / path from `pretrained`.
            source = (
                pretrained if isinstance(pretrained, str) else self.model.name_or_path
            )
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                source, **load_kwargs
            )
            return

        if isinstance(tokenizer, str):
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                tokenizer, **load_kwargs
            )
            return

        assert isinstance(
            tokenizer,
            (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast),
        )
        self.tokenizer = tokenizer

Baber Abbasi's avatar
Baber Abbasi committed
800
    def _detect_batch_size(self, requests: Sequence | None = None, pos: int = 0):
        """Find the largest batch size for which a dummy forward pass succeeds.

        The probe starts at ``self.max_batch_size``. ``find_executable_batch_size``
        halves the batch size and retries whenever the probe runs out of memory.
        If no size works (the "No executable batch size found" RuntimeError),
        the result falls back to 1.

        Dummy input lengths come from ``requests[pos]`` (left-truncated to
        ``self.max_length + 1`` tokens) when ``requests`` is non-empty.
        Otherwise ``self.max_length`` is used for every length. With more than
        one process, the minimum over all ranks is returned.
        """
        if not requests:
            max_length = max_context_enc = max_cont_enc = self.max_length
        else:
            _, context_enc, continuation_enc = requests[pos]
            window = self.max_length + 1
            max_length = len((context_enc + continuation_enc)[-window:][:-1])
            max_context_enc = len(context_enc[-window:])
            max_cont_enc = len(continuation_enc[-window:])

        # On OOM the decorator halves `batch_size` and calls the function again.
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size: int):
            if self.backend != "seq2seq":
                call_kwargs = {}
                test_batch = torch.ones(
                    (batch_size, max_length), device=self.device
                ).long()
            else:
                seq_len = max(max_context_enc, max_cont_enc)
                dummy_labels = torch.ones(
                    (batch_size, seq_len), device=self.device
                ).long()
                test_batch = torch.ones((batch_size, seq_len), device=self.device).long()
                call_kwargs = {"attn_mask": test_batch, "labels": dummy_labels}

            for _ in range(5):
                # keep the previous result bound while computing the next one,
                # so the probe's peak memory matches the original behavior
                out = F.log_softmax(  # noqa: F841
                    self._model_call(test_batch, **call_kwargs),
                    dim=-1,
                    dtype=self.softmax_dtype,
                )
            return batch_size

        try:
            batch_size = forward_batch()
        except RuntimeError as err:
            if "No executable batch size found" not in str(err):
                raise
            batch_size = 1

        if self.world_size > 1:
            # all ranks must agree, so keep the smallest size found anywhere
            local_bs = torch.tensor([batch_size], device=self.device)
            all_bs = self.accelerator.gather(local_bs).cpu().detach().numpy().tolist()
            batch_size = min(all_bs)

        clear_torch_cache()
        return batch_size

baberabb's avatar
baberabb committed
861
    def tok_encode(
        self,
        string: str,
        add_special_tokens: bool | None = None,
        left_truncate_len: int | None = None,
        **kwargs,
    ) -> list[int]:
        """Encode ``string`` into a list of token ids via ``self.tokenizer.encode``.

        Special-token handling:
        - causal backend: use ``add_special_tokens`` when given, otherwise
          ``self.add_bos_token``;
        - any other backend: forward ``add_special_tokens`` only when it is a
          bool; otherwise pass nothing, so the tokenizer's own default applies.

        When ``left_truncate_len`` is truthy, only the last ``left_truncate_len``
        ids are returned. Extra ``**kwargs`` are accepted and ignored.
        """
        if self.backend == "causal":
            use_special = (
                self.add_bos_token if add_special_tokens is None else add_special_tokens
            )
            encode_kwargs: dict = {"add_special_tokens": use_special}
        elif isinstance(add_special_tokens, bool):
            encode_kwargs = {"add_special_tokens": add_special_tokens}
        else:
            # leave the tokenizer's predefined default untouched
            encode_kwargs = {}

        token_ids = self.tokenizer.encode(string, **encode_kwargs)
        if not left_truncate_len:
            return token_ids
        # keep only the rightmost `left_truncate_len` tokens
        return token_ids[-left_truncate_len:]

haileyschoelkopf's avatar
haileyschoelkopf committed
891
    def tok_batch_encode(
        self,
        strings: list[str],
        padding_side: str = "left",
        left_truncate_len: int | None = None,
        truncation: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a batch of strings into padded PyTorch tensors.

        Unlike ``tok_encode``, this pads every sequence to the longest one in
        the batch (on ``padding_side``). It returns ``(input_ids,
        attention_mask)``. For the causal backend, ``add_special_tokens`` is
        taken from ``self.add_bos_token``. Other backends use the tokenizer
        default.

        If ``left_truncate_len`` is truthy, both tensors keep only their last
        ``left_truncate_len`` columns. A warning is logged when that actually
        drops tokens.

        The tokenizer's ``padding_side`` is overridden only for the duration
        of the call. It is restored even if tokenization raises.
        """
        special_kwargs = (
            {"add_special_tokens": self.add_bos_token}
            if self.backend == "causal"
            else {}
        )

        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            encoding = self.tokenizer(
                strings,
                truncation=truncation,
                padding="longest",
                return_tensors="pt",
                **special_kwargs,
            )
        finally:
            # the tokenizer is shared state; never leave it with our padding side
            self.tokenizer.padding_side = old_padding_side

        if left_truncate_len:
            original_lengths = encoding["input_ids"].size(1)
            if original_lengths > left_truncate_len:
                eval_logger.warning(
                    f"Left truncation applied. Original sequence length was {original_lengths}, "
                    f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
                )
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]

        return encoding["input_ids"], encoding["attention_mask"]

Baber Abbasi's avatar
Baber Abbasi committed
928
    def tok_decode(
        self,
        tokens: int | Sequence[int] | torch.Tensor,
        skip_special_tokens: bool = True,
    ) -> str:
        """Decode token ids back into text via ``self.tokenizer.decode``.

        :param tokens: the token id(s) to decode.
        :param skip_special_tokens: forwarded to the tokenizer; when True,
            special tokens are omitted from the decoded text.
        :return: the decoded string.
        """
        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
930

Baber Abbasi's avatar
Baber Abbasi committed
931
932
933
934
935
936
    def _model_call(
        self,
        inps: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run one forward pass without gradients and return the logits.

        Autocast is enabled only when ``self.mixed_precision_dtype`` is set.

        :param inps: torch.Tensor
            Input ids of shape [batch, (sequence_ctx + sequence_cont)] or
            [batch, sequence_ctx]; the sequence length may vary between calls.
        :param attn_mask: torch.Tensor, optional
            Shape [batch, (sequence_ctx + sequence_cont)]. Only passed (and then
            required together with ``labels``) when self.AUTO_MODEL_CLASS is
            transformers.AutoModelForSeq2SeqLM.
        :param labels: torch.Tensor, optional
            Shape [batch, (sequence_ctx + sequence_cont)]; same rules as
            ``attn_mask``.
        :return
            Logits of shape [batch, sequence, vocab] from the model's decoder.
        """
        with torch.no_grad():
            with torch.autocast(
                device_type=self.device.type,
                dtype=self.mixed_precision_dtype,
                enabled=self.mixed_precision_dtype is not None,
            ):
                seq2seq_inputs = attn_mask is not None or labels is not None
                if not seq2seq_inputs:
                    # plain decoder-only (or vision-to-seq) forward pass
                    assert self.AUTO_MODEL_CLASS in (
                        transformers.AutoModelForCausalLM,
                        transformers.AutoModelForVision2Seq,
                    )
                    return self.model(inps).logits

                # encoder-decoder: both mask and labels must be supplied
                assert attn_mask is not None and labels is not None
                assert transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS
                outputs = self.model(
                    input_ids=inps, attention_mask=attn_mask, labels=labels
                )
                return outputs.logits
972

Baber Abbasi's avatar
Baber Abbasi committed
973
974
975
976
977
    def _model_generate(
        self,
        context,
        max_length: int,
        stop: list[str],
        **generation_kwargs,
    ) -> torch.Tensor:
        """Generate from ``context`` with ``self.model.generate``.

        ``context`` is passed as ``input_ids``. Its ``shape[1]`` and
        ``shape[0]`` are handed to ``stop_sequences_criteria``, so it is
        presumably [batch, seq].

        Generation-kwarg normalisation:
        - ``temperature`` defaults to 0.0;
        - when temperature is 0.0 and ``do_sample`` was not given, greedy
          decoding is forced with ``do_sample=False``;
        - when ``do_sample`` is False and temperature is 0.0, the temperature
          key is dropped, since greedy decoding ignores it and HF would warn.
        """
        temperature = generation_kwargs.setdefault("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample")

        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
        if temperature == 0.0 and do_sample is None:
            do_sample = False
            generation_kwargs["do_sample"] = False

        if do_sample is False and temperature == 0.0:
            del generation_kwargs["temperature"]

        criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )
        autocast_enabled = self.mixed_precision_dtype is not None
        with torch.autocast(
            device_type=self.device.type,
            dtype=self.mixed_precision_dtype,
            enabled=autocast_enabled,
        ):
            return self.model.generate(
                input_ids=context,
                max_length=max_length,
                stopping_criteria=criteria,
                pad_token_id=self.tokenizer.pad_token_id,
                use_cache=True,
                **generation_kwargs,
            )
1010

Baber Abbasi's avatar
Baber Abbasi committed
1011
    def _select_cont_toks(
Baber Abbasi's avatar
Baber Abbasi committed
1012
1013
1014
1015
        self,
        logits: torch.Tensor,
        contlen: int | None = None,
        inplen: int | None = None,
Baber Abbasi's avatar
Baber Abbasi committed
1016
    ) -> torch.Tensor:
1017
        if self.backend == "causal":
Baber Abbasi's avatar
Baber Abbasi committed
1018
1019
1020
            assert contlen and inplen, (
                "Must pass input len and cont. len to select scored logits for causal LM"
            )
1021
1022
1023
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            logits = logits[inplen - contlen : inplen]
1024
        elif self.backend == "seq2seq":
Baber Abbasi's avatar
Baber Abbasi committed
1025
1026
1027
            assert contlen and not inplen, (
                "Selecting scored logits for Seq2SeqLM requires only cont. len"
            )
haileyschoelkopf's avatar
haileyschoelkopf committed
1028
            # only discard right-padding.
1029
            # the logits input to this fn only contain decoder-side tokens.
haileyschoelkopf's avatar
haileyschoelkopf committed
1030
1031
            logits = logits[:contlen]

1032
1033
        return logits

1034
    def loglikelihood_rolling(
        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[float]:
        """Return the total log-likelihood of each request's string.

        Each string is tokenized and split into disjoint rolling windows of at
        most ``self.max_length`` tokens (``utils.get_rolling_token_windows`` and
        ``utils.make_disjoint_window``). The windows of all requests are scored
        together in batches via ``_loglikelihood_tokens``. Each window's
        log-likelihood (the first element of its result) is then summed per
        request. Each total is also passed to ``self.cache_hook.add_partial``.

        With more than one process, every rank pads its window list to the
        largest count across ranks, so every rank processes the same number of
        windows. The padded results are discarded.

        ``disable_tqdm`` silences both the tokenization progress bar and the
        per-batch scoring progress bars.
        """
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        # First, collect all windows from all requests
        all_windows = []  # List of (request_idx, window) tuples
        request_window_counts = []  # Track number of windows per request

        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [(None,) + x for x in rolling_token_windows]

            # Store windows with their request index
            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))

        # Handle distributed case padding
        pad_amnt = 0
        if self.world_size > 1:
            mytensor = torch.tensor(len(all_windows), device=self.device)
            gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
            pad_amnt = max(gathered) - gathered[self.rank]
            if pad_amnt > 0:
                if all_windows:
                    filler = all_windows[0]
                else:
                    # This rank has no windows of its own but must still take
                    # part in the collective batches: use a minimal one-token
                    # window. Its score is stripped off below.
                    filler = (
                        -1,
                        (None, [self.prefix_token_id], [self.prefix_token_id]),
                    )
                all_windows += pad_amnt * [filler]

        all_nlls = []
        batch_size = adaptive_batch_size or self.batch_size
        for i in range(0, len(all_windows), batch_size):
            batch = all_windows[i : i + batch_size]
            # Extract just the windows for processing, keeping track of request indices
            batch_indices, batch_windows = zip(*batch)

            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                disable_tqdm=disable_tqdm,
                override_bs=len(batch_windows),
            )
            # Store results with their request indices
            all_nlls.extend(zip(batch_indices, batch_nlls))

        # Remove padding if necessary
        if (self.world_size > 1) and (pad_amnt > 0):
            all_nlls = all_nlls[:-pad_amnt]

        # Reconstruct per-request loglikelihoods; windows were appended in
        # request order, so each request owns a contiguous slice of all_nlls.
        loglikelihoods = []
        current_idx = 0
        for request, window_count in zip(requests, request_window_counts):
            request_nlls = all_nlls[current_idx : current_idx + window_count]
            current_idx += window_count
            # each entry is (request_idx, (loglikelihood, is_greedy)); is_greedy is discarded
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (request.args[0],), request_total
            )

        return loglikelihoods
Zhiwei Zhuang's avatar
Zhiwei Zhuang committed
1119

1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
    def _batch_scheduler(self, pos, n_reordered_requests):
        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
        if sched in self.batch_sizes:
            return self.batch_sizes[sched]
        if (len(self.batch_sizes) > 1) and (
            self.batch_sizes[sched - 1] == self.max_batch_size
        ):
            # if previous batch size is already maximal, skip recomputation
            self.batch_sizes[sched] = self.max_batch_size
            return self.batch_sizes[sched]
        print(
            f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
        )
Zhiwei Zhuang's avatar
Zhiwei Zhuang committed
1133
        self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
1134
1135
        print(f"Determined largest batch size: {self.batch_sizes[sched]}")
        return self.batch_sizes[sched]
1136

Ethan Smith's avatar
Ethan Smith committed
1137
    def _loglikelihood_tokens(
baberabb's avatar
baberabb committed
1138
        self,
Baber Abbasi's avatar
Baber Abbasi committed
1139
        requests: list[tuple[tuple[str, str], list[int], list[int]]],
baberabb's avatar
baberabb committed
1140
        disable_tqdm: bool = False,
Baber Abbasi's avatar
Baber Abbasi committed
1141
1142
        override_bs: int | None = None,
    ) -> list[tuple[float, bool]]:
1143
1144
1145
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

Baber Abbasi's avatar
Baber Abbasi committed
1146
1147
        def _collate(req: tuple[tuple[str, str], list[int], list[int]]):
            """Defines the key for the sorted method."""
1148
1149
1150
1151
1152
1153
1154
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

Baber Abbasi's avatar
Baber Abbasi committed
1155
            toks = req[1] + req[2]
1156
1157
            return -len(toks), tuple(toks)

Baber Abbasi's avatar
Baber Abbasi committed
1158
1159
        def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]):
            """Defines the key to group and lookup one-token continuations."""
Baber Abbasi's avatar
Baber Abbasi committed
1160
            # Use with group_by="contexts" (optional)"
Baber Abbasi's avatar
Baber Abbasi committed
1161
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
Baber Abbasi's avatar
Baber Abbasi committed
1162
1163
1164
1165
1166
1167
1168
1169
            # speeds up some multiple-choice tasks proportionally to the number of choices.
            # groups requests by context+continuation[:-1] and infer on one request/group.
            return req[-2] + req[-1][:-1]

        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by="contexts"
1170
            if self.backend == "causal" and self.logits_cache
Baber Abbasi's avatar
Baber Abbasi committed
1171
1172
1173
            else None,
            group_fn=_lookup_one_token_cont,
        )
Benjamin Fattori's avatar
Benjamin Fattori committed
1174
1175
1176

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
Baber Abbasi's avatar
Baber Abbasi committed
1177
1178
1179
        n_reordered_requests = len(re_ord)
        batch_size = (
            self.batch_size
1180
1181
1182
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
Baber Abbasi's avatar
Baber Abbasi committed
1183
1184
1185
1186
            else 0
        )
        batch_fn = (
            self._batch_scheduler
1187
1188
1189
            if self.batch_size == "auto"
            and n_reordered_requests > 0
            and not override_bs
Baber Abbasi's avatar
Baber Abbasi committed
1190
            else None
1191
1192
        )

Baber Abbasi's avatar
Baber Abbasi committed
1193
        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
1194
1195
1196
1197
1198
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
haileyschoelkopf's avatar
haileyschoelkopf committed
1199
        for chunk in chunks:
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

haileyschoelkopf's avatar
haileyschoelkopf committed
1219
                # how this all works (illustrated on a causal decoder-only setup):
1220
1221
1222
1223
1224
1225
1226
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
1227
                if self.backend == "causal":
1228
1229
                    total_length = len(context_enc) + len(continuation_enc)
                    if total_length > self.max_length + 1:
1230
                        eval_logger.warning(
1231
1232
1233
1234
                            f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
                            f"exceeds model's maximum length ({self.max_length}). "
                            f"Truncating {total_length - self.max_length + 1} tokens from the left."
                        )
1235
1236
1237
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
1238
1239
                        device=self.device,
                    )
1240
                    (inplen,) = inp.shape
1241
                elif self.backend == "seq2seq":
1242
1243
1244
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
haileyschoelkopf's avatar
haileyschoelkopf committed
1245
                        device=self.device,
1246
                    )
1247
                    (inplen,) = inp.shape
1248
1249
1250
1251

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

1252
                    cont = torch.tensor(
haileyschoelkopf's avatar
haileyschoelkopf committed
1253
                        (continuation_enc)[-self.max_length :],
1254
1255
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
1256
                        dtype=torch.long,
1257
1258
                        device=self.device,
                    )
1259
1260
                    (contlen,) = cont.shape

1261
1262
                    conts.append(cont)

haileyschoelkopf's avatar
haileyschoelkopf committed
1263
1264
1265
1266
1267
                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )
1268

haileyschoelkopf's avatar
haileyschoelkopf committed
1269
1270
1271
1272
1273
                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )
1274
1275
1276
1277

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)
haileyschoelkopf's avatar
haileyschoelkopf committed
1278

1279
1280
            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
1281
            if self.backend == "causal":
1282
                batched_inps = pad_and_concat(
haileyschoelkopf's avatar
haileyschoelkopf committed
1283
1284
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
1285
            elif self.backend == "seq2seq":
1286
                # TODO: left-pad encoder inps and mask?
1287
                batched_inps = pad_and_concat(
haileyschoelkopf's avatar
haileyschoelkopf committed
1288
1289
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
1290
                batched_conts = pad_and_concat(
haileyschoelkopf's avatar
haileyschoelkopf committed
1291
1292
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
1293
                batched_encoder_mask = pad_and_concat(
haileyschoelkopf's avatar
haileyschoelkopf committed
1294
1295
1296
1297
1298
1299
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }
1300
1301

            multi_logits = F.log_softmax(
1302
1303
1304
                self._model_call(batched_inps, **call_kwargs),
                dim=-1,
                dtype=self.softmax_dtype,
1305
            )  # [batch, padding_length (inp or cont), vocab]
1306

Baber Abbasi's avatar
Baber Abbasi committed
1307
            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
1308
1309
1310
1311
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
haileyschoelkopf's avatar
haileyschoelkopf committed
1312
                # take only logits in the continuation
1313
                # (discard context toks if decoder-only ; discard right-padding)
1314
1315
                # also discards + checks for "virtual tokens" in the causal LM's input window
                # from prompt/prefix tuning tokens, if applicable
haileyschoelkopf's avatar
haileyschoelkopf committed
1316
                ctx_len = (
1317
                    inplen + (logits.shape[0] - padding_len_inp)
1318
                    if self.backend == "causal"
haileyschoelkopf's avatar
haileyschoelkopf committed
1319
1320
                    else None
                )
1321
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
haileyschoelkopf's avatar
haileyschoelkopf committed
1322
                logits = logits.unsqueeze(0)  # [1, seq, vocab]
1323
1324
1325
1326

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)

Baber Abbasi's avatar
Baber Abbasi committed
1327
1328
1329
1330
1331
                # check for one-token continuation cache hits.
                # noop in case group_by != "contexts" or no cache hit and returns the
                # original args. Otherwise, expands the logits batch dimension and yields each
                # batch along with matching continuation tokens and prompt strings.
                # logits -> [1, seq, vocab]
Baber Abbasi's avatar
Baber Abbasi committed
1332
                for request_str, cont_toks, logits in re_ord.get_cache(  # noqa
Baber Abbasi's avatar
Baber Abbasi committed
1333
1334
1335
1336
1337
1338
1339
1340
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
                    logits=logits,
                ):
                    cont_toks = torch.tensor(
                        cont_toks, dtype=torch.long, device=self.device
                    ).unsqueeze(0)  # [1, seq]
1341
1342
1343
1344
1345
1346
                    # Use trailing slice [-cont_toks.shape[1]:] to handle variable length cont_len (but same ctx+cont[:-1]).
                    # i.e. continuations can be sliced at diff points. Collator ensures we have sufficient greedy_tokens
                    # by choosing key with longest cont if group_by="contexts".
                    max_equal = (
                        greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks
                    ).all()
Baber Abbasi's avatar
Baber Abbasi committed
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358

                    # Obtain log-probs at the corresponding continuation token indices
                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                        -1
                    )  # [1, seq]

                    # Answer: (log prob, is-exact-match)
                    answer = (float(logits.sum()), bool(max_equal))

                    res.append(answer)

1359
1360
1361
1362
1363
1364
1365
                    if request_str is not None:
                        # special case: loglikelihood_rolling produces a number of loglikelihood requests
                        # all with cache key None. instead do add_partial on the per-example level
                        # in the loglikelihood_rolling() function for those.
                        self.cache_hook.add_partial(
                            "loglikelihood", request_str, answer
                        )
Baber Abbasi's avatar
Baber Abbasi committed
1366
                    pbar.update(1)
haileyschoelkopf's avatar
haileyschoelkopf committed
1367
1368

        pbar.close()
haileyschoelkopf's avatar
haileyschoelkopf committed
1369

1370
1371
        return re_ord.get_original(res)

1372
    def generate_until(
        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[str]:
        """Generate a continuation for each request until a stop sequence or length limit.

        Requests are grouped by their generation kwargs (so e.g. greedy and
        sampled generation never share a batch), sorted longest-context first,
        generated in batches, post-processed, and returned in the original
        request order.

        Args:
            requests: Instances whose ``args`` are ``(context, gen_kwargs)``
                tuples; ``gen_kwargs`` must be a dict.
            disable_tqdm: Suppress the progress bar (it is always suppressed on
                non-zero ranks).

        Returns:
            One generated string per request, aligned with ``requests``.

        Raises:
            TypeError: If a request's ``gen_kwargs`` is not a dict.
            ValueError: If, for a causal model, ``max_gen_toks`` leaves no room
                for context within ``self.max_length``, or if ``self.backend`` is
                neither ``"causal"`` nor ``"seq2seq"``.
        """
        res = []

        def _collate(req: tuple[str, dict]):
            """Sort key: longest tokenized context first, ties broken by the text."""
            # The negative length sorts descending, which means:
            # - time estimates are over- rather than under-estimates,
            # - the first item of a batch fixes that batch's padded context
            #   length, which keeps (adaptive) batching simple,
            # - any OOM happens right away instead of near the end of the run.
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )

        adaptive_batch_size = None
        if self.batch_size == "auto":
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        # A per-batch scheduler is only used when "auto" was requested but no
        # usable fixed batch size was detected.
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # Group requests by their generation_kwargs so that e.g. greedy decoding
        # and temp=0.8 sampling are never executed in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)

        # Loop-invariant: how the think-end marker (if any) is configured.
        think_end_id = (
            self.think_end_token if isinstance(self.think_end_token, int) else None
        )
        think_end_str = (
            self.think_end_token if isinstance(self.think_end_token, str) else None
        )

        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # All requests in a chunk share the same gen kwargs: the Collator
            # groups by them, so the first entry is representative.
            gen_kwargs = all_gen_kwargs[0]
            if not isinstance(gen_kwargs, dict):
                raise TypeError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            # Deep copy so the pops below never mutate the request's own kwargs
            # (edge case for repeats > 1).
            kwargs = copy.deepcopy(gen_kwargs)
            # add EOS token to stop sequences
            until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

            # set the max length in tokens of inputs ("context_enc")
            if self.backend == "causal":
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
                # Explicit check rather than `assert`, which is stripped under -O.
                if max_ctx_len <= 0:
                    raise ValueError(
                        f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
                    )
            elif self.backend == "seq2seq":
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length
            else:
                raise ValueError(
                    f"Unsupported backend {self.backend!r}; expected 'causal' or 'seq2seq'."
                )

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)
            prompt_len = context_enc.shape[1]

            if "max_length" not in kwargs:
                kwargs["max_length"] = prompt_len + max_gen_toks

            # perform batched generation
            cont = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            for cont_toks, context in zip(cont.tolist(), contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.backend == "causal":
                    cont_toks = cont_toks[prompt_len:]

                # Integer think-end token: keep only the tokens after its last occurrence.
                if think_end_id is not None and think_end_id in cont_toks:
                    last_idx = len(cont_toks) - 1 - cont_toks[::-1].index(think_end_id)
                    cont_toks = cont_toks[last_idx + 1 :]

                s = self.tok_decode(cont_toks)

                # With an integer think-end token configured, leading whitespace is
                # stripped whether or not the token was actually found above.
                if think_end_id is not None:
                    s = s.lstrip()

                # Apply post-processing: remove stop sequences and string-based thinking tokens
                s = postprocess_generated_text(
                    generation=s,
                    stop=until,
                    think_end_token=think_end_str,
                )

                res.append(s)

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res
1516

Baber Abbasi's avatar
Baber Abbasi committed
1517
    def apply_chat_template(
        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """Render ``chat_history`` into a prompt string via the tokenizer's chat template.

        If rendering raises ``jinja2.exceptions.TemplateError``, a warning is
        logged and rendering is retried once with every ``"system"`` message
        removed. An error raised by that retry propagates to the caller.

        Args:
            chat_history: Chat messages; each dict must have a ``"role"`` key.
            add_generation_prompt: Passed through to the tokenizer. Its negation
                is passed as ``continue_final_message``.

        Returns:
            The templated, untokenized prompt string.
        """

        def _render(messages: list[dict[str, str]]) -> str:
            # Single place for the tokenizer call so both attempts stay identical.
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=not add_generation_prompt,
                **self.chat_template_args,
            )

        try:
            return _render(chat_history)
        except jinja2.exceptions.TemplateError:
            eval_logger.warning(
                "Failed to apply chat template. removing the system role in chat history."
            )
            # Retry inside the handler so a second failure keeps the first as context.
            return _render([msg for msg in chat_history if msg["role"] != "system"])
KonradSzafer's avatar
KonradSzafer committed
1543

1544
    def get_model_info(self) -> dict:
        """Collect Hugging Face model metadata for experiment reproducibility.

        Returns:
            A dict with these keys:

            - ``model_num_parameters``: int, ``-1`` if it cannot be determined.
            - ``model_dtype``: string form of ``self._model.dtype``, or ``""`` if
              the model has no ``dtype`` attribute.
            - ``model_revision``: ``self.revision``.
            - ``model_sha``: the Hub commit SHA for ``self.pretrained``, or ``""``
              if the lookup failed.

            ``peft_sha`` / ``delta_sha`` are added when ``self.peft`` /
            ``self.delta`` are set.
        """

        def get_model_num_params(model) -> int:
            # Prefer a model-provided num_parameters(); otherwise sum tensor sizes.
            if hasattr(model, "num_parameters"):
                return model.num_parameters()
            if hasattr(model, "parameters"):
                return sum(p.numel() for p in model.parameters())
            return -1

        def get_model_dtype(model) -> str:
            # Stringify so the value matches the annotation and is directly serializable.
            if hasattr(model, "dtype"):
                return str(model.dtype)
            return ""

        def get_model_sha(pretrained: str, revision: str) -> str:
            # Best effort: any failure (e.g. local path, no network, unknown
            # revision) is logged at debug level and yields "" instead of raising.
            try:
                model_info = HfApi().model_info(repo_id=pretrained, revision=revision)
                return model_info.sha
            except Exception as e:
                eval_logger.debug(
                    f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
                )
                return ""

        model_info = {
            "model_num_parameters": get_model_num_params(self._model),
            "model_dtype": get_model_dtype(self._model),
            "model_revision": self.revision,
            "model_sha": get_model_sha(self.pretrained, self.revision),
        }
        # NOTE(review): the base model's revision is reused for the adapter/delta
        # lookups below — confirm this is intended for separately-versioned repos.
        if self.peft:
            model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
        if self.delta:
            model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
        return model_info