huggingface.py 60.8 KB
Newer Older
1
import copy
2
import os
Jeevan's avatar
Jeevan committed
3
from datetime import timedelta
4
from pathlib import Path
KonradSzafer's avatar
KonradSzafer committed
5
from typing import Dict, List, Literal, Optional, Tuple, Union
6

7
import torch
8
import torch.nn.functional as F
9
import transformers
Jeevan's avatar
Jeevan committed
10
11
12
13
14
from accelerate import (
    Accelerator,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
Nathan Habib's avatar
Nathan Habib committed
15
from accelerate.utils import get_max_memory
16
from huggingface_hub import HfApi
17
18
19
20
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
21
22
23
24
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
25
26

from lm_eval import utils
baberabb's avatar
baberabb committed
27
from lm_eval.api.instance import Instance
28
from lm_eval.api.model import TemplateLM
29
from lm_eval.api.registry import register_model
30
31
32
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
33
    configure_pad_token,
34
35
36
37
    get_dtype,
    pad_and_concat,
    stop_sequences_criteria,
)
38

39

40
eval_logger = utils.eval_logger
41

lintangsutawika's avatar
lintangsutawika committed
42

43
@register_model("hf-auto", "hf", "huggingface")
44
class HFLM(TemplateLM):
45
46
47
48
49
50
51
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

52
    AUTO_MODEL_CLASS = None
53
    _DEFAULT_MAX_LENGTH = 2048
haileyschoelkopf's avatar
haileyschoelkopf committed
54

55
56
    def __init__(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ] = None,
        truncation: Optional[bool] = False,
        logits_cache: bool = True,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        max_batch_size: Optional[int] = 64,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
        # PEFT, delta weights and quantization options
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        **kwargs,
    ) -> None:
        """
        Build the wrapper around either a model name/path or an existing model object.

        When `pretrained` is a string, the config, tokenizer and model are loaded
        (in that order) and device / data-parallel state is derived from
        `accelerate`. When `pretrained` is already a model object, it is used
        as-is and the distributed setup is skipped (rank 0, world size 1).

        Args:
            pretrained: Model identifier string, or an already-initialized model.
            backend: "default" to infer the model class from the config, or
                "causal" / "seq2seq" to force it.
            revision: Revision passed to the loaders; `subfolder`, if given, is
                appended to it as "<revision>/<subfolder>".
            subfolder: Optional subfolder appended to `revision`.
            tokenizer: Forwarded to `_create_tokenizer`.
            truncation: Stored on `self.truncation`.
            logits_cache: Stored on `self.logits_cache`.
            max_length: Manual maximum length; stored on `self._max_length`.
            device: Requested device; only honoured in the single-process,
                non-parallelized case.
            dtype: Forwarded to `_create_model`.
            batch_size: Integer, or a string starting with "auto" optionally
                followed by ":<schedule>".
            max_batch_size: Stored on `self.max_batch_size`.
            trust_remote_code: Forwarded to the config/tokenizer/model loaders.
            use_fast_tokenizer: Forwarded to `_create_tokenizer`.
            add_bos_token: Stored on `self.add_bos_token` (forced to True when
                the config's `model_type` contains "gemma").
            prefix_token_id: Optional override stored on
                `self.custom_prefix_token_id`.
            parallelize, max_memory_per_gpu, max_cpu_memory, offload_folder:
                Forwarded to `_create_model`.
            peft, delta, autogptq: Forwarded to `_create_model`.
            **kwargs: Forwarded to `_create_model`.

        Raises:
            AssertionError: If `parallelize` is set together with a
                pre-initialized model, or if argument types are unexpected.
            RuntimeError: If an "mps" device is requested with torch < 2.1.
        """
        super().__init__()

        # optionally: take in an already-initialized transformers.PreTrainedModel
        if not isinstance(pretrained, str):
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
            )
            assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
            gpus = 0

        else:
            assert isinstance(device, str)
            assert isinstance(pretrained, str)
            assert isinstance(batch_size, (int, str))

            gpus = torch.cuda.device_count()
            accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
            accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
            # only keep the accelerator on `self` when several processes were spawned;
            # `hasattr(self, "accelerator")` is used elsewhere as the multi-process flag
            if accelerator.num_processes > 1:
                self.accelerator = accelerator

            if "npu" in accelerator.device.type:
                gpus = torch.npu.device_count()

            # using one process with no model parallelism
            if not (parallelize or accelerator.num_processes > 1):
                # use user-passed device
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(gpus)]
                    + ["mps", "mps:0"]
                    + [f"npu:{i}" for i in range(gpus)]
                )
                if device and device in device_list:
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and version.parse(
                        torch.__version__
                    ) < version.parse("2.1"):
                        raise RuntimeError(
                            f"mps requires torch >= 2.1. You have {torch.__version__}"
                        )
                else:
                    if device:
                        # a device was requested but is not one we recognise: say so
                        # explicitly instead of claiming that none was specified
                        eval_logger.warning(
                            f"Device '{device}' is not a recognized device; falling back to automatic device selection."
                        )
                    else:
                        eval_logger.info("Device not specified")
                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                    self._device = (
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    )
            else:  # Parallelism managed by accelerate
                if device != "cuda":
                    eval_logger.info(
                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                    )
                # TODO: include in warning that `load_in_8bit` etc. affect this too
                self._device = (
                    self.accelerator.device
                    if hasattr(self, "accelerator")
                    else torch.device(device)
                )

            revision = str(revision)  # cast to string if not already one
            # TODO: update this to be less of a hack once subfolder is fixed in HF
            revision = revision + ("/" + subfolder if subfolder is not None else "")

            self._get_config(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
            )

        # determine which of 'causal' and 'seq2seq' backends to use
        self._get_backend(
            config=self.config, backend=backend, trust_remote_code=trust_remote_code
        )

        # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT
        self._create_tokenizer(
            pretrained,
            tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
        )

        # if we passed `pretrained` as a string, initialize our model now
        if isinstance(pretrained, str):
            self._create_model(
                pretrained=pretrained,
                revision=revision,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                parallelize=parallelize,
                gpus=gpus,
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                peft=peft,
                delta=delta,
                autogptq=autogptq,
                **kwargs,
            )

        # access self._model through self.model property outside this method
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model.tie_weights()

        self.truncation = truncation
        self.logits_cache = logits_cache
        self.vocab_size = self.tokenizer.vocab_size
        # select (or create) a pad token to use
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)

        self.add_bos_token = add_bos_token
        if "gemma" in getattr(self.config, "model_type", ""):
            self.add_bos_token = True
            eval_logger.info(
                f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
            )

        self._max_length = max_length
        self.pretrained = pretrained
        self.delta = delta
        self.peft = peft
        self.revision = revision
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size

        # "auto" or "auto:<schedule>" keeps the string "auto" as the batch size and
        # parses the optional schedule; anything else must be convertible to int
        if str(batch_size).startswith("auto"):
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            if gpus >= 1 or str(self.device) == "mps":
                # TODO: can remove this whole snippet except in the mps case, perhaps?
                if not (parallelize or autogptq or hasattr(self, "accelerator")):
                    # place model onto device requested manually,
                    # if not using HF Accelerate or device_map
                    # or any other option that preloads model onto device
                    try:
                        self.model.to(self.device)
                    except ValueError:
                        eval_logger.debug(
                            "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                        )
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if accelerator.num_processes > 1:
                    if parallelize:
                        eval_logger.warning(
                            "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
                        )
                    elif gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                        if self.accelerator.is_local_main_process:
                            eval_logger.info(
                                f"Using {gpus} devices with data parallelism"
                            )

                    self._device = torch.device(f"{accelerator.device}")
                    self.accelerator = accelerator

                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
                else:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )
287

Nathan Habib's avatar
Nathan Habib committed
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
    def _get_accelerate_args(
        self,
        parallelize: Optional[bool] = None,
        device_map: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        gpus: Optional[int] = None,
    ) -> dict:
        """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`.

        Args:
            parallelize: Whether to shard the model across devices. If None, it
                is enabled when there are more GPUs than local processes.
            device_map: Only consulted when model parallelism is off: None means
                "place on this process' device", anything else yields a
                `device_map` of None in the result.
            max_memory_per_gpu: Per-GPU memory budget applied to every GPU index
                in `range(gpus)`; if None, budgets come from `get_max_memory()`.
            max_cpu_memory: Optional CPU memory budget, added under the "cpu" key
                of the returned `max_memory` mapping.
            offload_folder: Returned as `offload_folder` when parallelizing.
            gpus: Number of GPUs; falls back to `torch.cuda.device_count()` when
                None and a per-GPU budget was given.

        Returns:
            A dict with `max_memory` and `device_map` keys, plus `offload_folder`
            when model parallelism is used.
        """
        num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
        if (
            num_machines == 0
            and hasattr(self, "accelerator")
            and self.accelerator is not None
        ):
            eval_logger.info(
                "We are not in a distributed setting for accelerate. Setting model_parallel to False."
            )
            parallelize = False

        if parallelize is None:
            # If parallelism is unset by the user, we automatically assign model parallelism
            # if enough extra GPUs are available
            max_memory_all_gpus = get_max_memory()
            # We just want gpu, not cpu, max memory
            if "cpu" in max_memory_all_gpus:
                del max_memory_all_gpus["cpu"]
            parallelize = bool(num_local_processes < len(max_memory_all_gpus))
            eval_logger.info(
                f"Setting model parallel to {parallelize} since "
                f"the number of local processes is {num_local_processes} "
                f"and the number of GPUs is {len(max_memory_all_gpus)}"
            )

        args = {}
        if parallelize:  # Model parallelism will be used
            if max_memory_per_gpu is not None:  # Using the provided memory requirements
                # `gpus` defaults to None; avoid `range(None)` by asking torch directly
                device_count = gpus if gpus is not None else torch.cuda.device_count()
                max_memory_per_gpu_map = {
                    device_idx: max_memory_per_gpu for device_idx in range(device_count)
                }
            else:  # Estimating the possible memory requirements
                max_memory_all_gpus = get_max_memory()
                if "cpu" in max_memory_all_gpus:
                    del max_memory_all_gpus["cpu"]
                if not hasattr(self, "accelerator"):
                    max_memory_per_gpu_map = {
                        k: v for k, v in max_memory_all_gpus.items()
                    }
                else:
                    # use only 1 / num_processes of the GPUs if we are running under accelerate launch
                    max_memory_per_gpu_map = {
                        k: v
                        for k, v in max_memory_all_gpus.items()
                        if k % num_local_processes
                        == (self.accelerator.process_index % num_local_processes)
                    }

            # Merge the optional CPU budget into the mapping that is actually
            # returned; previously it was stored in a dict that was thrown away.
            max_memory = dict(max_memory_per_gpu_map)
            if max_cpu_memory is not None:
                max_memory["cpu"] = max_cpu_memory

            args["max_memory"] = max_memory
            args["device_map"] = "auto"
            eval_logger.info(
                f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to 'auto'"
            )

            args["offload_folder"] = offload_folder
        elif (
            device_map is None
        ):  # No model parallelism, we use the default provided device for our model
            if hasattr(self, "accelerator"):
                device_map = {"": f"{self.accelerator.device}"}
            else:
                device_map = {"": str(self.device)}
            args["max_memory"] = None
            args["device_map"] = device_map
            eval_logger.info(
                f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
            )
        else:
            args["max_memory"] = None
            args["device_map"] = None
            eval_logger.info("Model parallel was set to False.")

        return args

376
377
378
379
380
    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

381
382
383
384
385
386
387
388
    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

389
390
391
392
393
    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

394
395
396
397
398
399
400
401
402
    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

403
404
    @property
    def max_length(self):
405
406
407
408
409
410
411
412
413
414
415
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH
416

417
    @property
Ethan Smith's avatar
Ethan Smith committed
418
    def max_gen_toks(self) -> int:
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

KonradSzafer's avatar
KonradSzafer committed
437
438
439
440
    @property
    def tokenizer_name(self) -> str:
        return self.tokenizer.name_or_path.replace("/", "__")

441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """
        Get the appropriate chat template for the model based on configuration and input.
        This method determines, and returns the correct chat template, ensuring reproducibility.

        The template selection logic is adapted from the Transformers library's `apply_chat_template`
        method in the Tokenizer class. The original implementation can be found at:
        https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687

        This method ensures that the right template is chosen based on the following:
        1. If the model's tokenizer has multiple templates:
            a. Use the specified template if it exists in the dictionary.
            b. Use the default template from the list if no specific template is provided.
            c. Raise an error if no default template exists and no specific template is provided.
        2. If the model's tokenizer has a single template or no template:
            a. Use the tokenizer's chat template if available.
            b. Fall back to the default chat template if no tokenizer chat template exists.

        Args:
            chat_template (Union[bool, str]): Specifies the chat template to use.
                - If False or None, no template is applied.
                - If True, the default or only available template is used.
                - If a string, the template with the matching name is used.

        Returns:
            Optional[str]: The selected chat template, or None if no template is applied.
        """
        if chat_template is False or chat_template is None:
            eval_logger.warning(
                "model.chat_template was called with the chat_template set to False or None. "
                "Therefore no chat template will be applied. Make sure this is an intended behavior."
            )
            return None

        # Convert boolean chat_template to None to ensure compatibility with the adapted logic
        if isinstance(chat_template, bool):
            chat_template = None
        using_default_template = False

        # First, handle the cases when the model has a dict of multiple templates
        template = self.tokenizer.chat_template or self.tokenizer.default_chat_template

        if isinstance(template, dict):
            using_default_dict = self.tokenizer.chat_template is None

            if chat_template is not None:
                if chat_template in template:
                    selected_template = template[chat_template]
                    if using_default_dict:
                        using_default_template = True
                else:
                    raise ValueError(
                        f"The specified chat template '{chat_template}' is not available. "
                        f"Available template names are {sorted(template.keys())}."
                    )
            else:
                # If user didn't pass a chat template, use the default template from the dict
                if "default" in template:
                    selected_template = template["default"]
                    using_default_template = True
                else:
                    raise ValueError(
                        "This model has multiple chat templates with no default specified! Please either pass a chat "
                        "template or the name of the template you wish to use to the `chat_template` argument. Available "
                        f"template names are {sorted(template.keys())}."
                    )

        # Cases when the model has a single template or no template
        else:
            # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
            if isinstance(chat_template, str):
                eval_logger.warning(
                    "Chat template name provided, but the tokenizer's chat template is not a dictionary. "
                    "Using the tokenizer's chat template or the default template instead."
                )
            if self.tokenizer.chat_template is not None:
                selected_template = self.tokenizer.chat_template
            else:
                selected_template = self.tokenizer.default_chat_template
                using_default_template = True

        if using_default_template:
            eval_logger.warning(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        return selected_template
KonradSzafer's avatar
KonradSzafer committed
532

533
534
    def _get_backend(
        self,
        config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
        trust_remote_code: Optional[bool] = False,
    ) -> None:
        """
        Helper method during initialization.
        Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder))
        model type to be used, and stores the matching auto-model class on
        `self.AUTO_MODEL_CLASS`.

        Args:
            config: Model configuration whose `model_type` is looked up in the
                seq2seq and causal registries when `backend` is "default".
            backend: "default" to infer from `config`, otherwise the backend to force.
            trust_remote_code: When False, a warning is logged if the model type
                is in neither registry.

        Raises:
            AssertionError: If `backend` is not one of the accepted values.
        """
        assert backend in ["default", "causal", "seq2seq"]

        if backend != "default":
            # if we've settled on non-default backend, use that manually
            if backend == "causal":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            elif backend == "seq2seq":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
            eval_logger.info(
                f"Overrode HF model backend type, and using type '{backend}'"
            )
        else:
            # determine and use the default HF backend for this model, based on its config + metadata.
            # Both lookups use the `config` argument so the decision depends only
            # on what the caller passed in.
            model_type = getattr(config, "model_type")
            if model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                # first check if model type is listed under seq2seq models, since some
                # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                # these special cases should be treated as seq2seq models.
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
            elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            else:
                if not trust_remote_code:
                    eval_logger.warning(
                        "HF model type is neither marked as CausalLM or Seq2SeqLM. "
                        "This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
                    )
                # if model type is neither in HF transformers causal or seq2seq model registries
                # then we default to AutoModelForCausalLM
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

        assert self.AUTO_MODEL_CLASS in [
            transformers.AutoModelForCausalLM,
            transformers.AutoModelForSeq2SeqLM,
        ]
        return None

    def _get_config(
        self,
        pretrained: str,
        revision: str = "main",
        trust_remote_code: bool = False,
    ) -> None:
        """Load the config for `pretrained` via `AutoConfig` and store it on `self._config`.

        Args:
            pretrained: Identifier passed to `AutoConfig.from_pretrained`.
            revision: Revision forwarded to the loader.
            trust_remote_code: Forwarded to the loader.
        """
        loaded_config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        self._config = loaded_config

    def _create_model(
        self,
        pretrained: str,
        revision: Optional[str] = "main",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        trust_remote_code: Optional[bool] = False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # (accelerate naive PP (device_map) options)
        parallelize: Optional[bool] = False,
        gpus: Optional[int] = None,
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        # PEFT, delta weights and quantization options
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        **kwargs,
    ) -> None:
        """
        Initializes an HF or HF-compatible PreTrainedModel from scratch
        inside HFLM, using the kwargs passed into self.__init__().

        Also handles functionality such as AutoGPTQ usage and PEFT wrapping.

        For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
        (such as PyTorch models that are nearly, but not quite, fully mirroring
        HF's public interface relied on in this HFLM class)
        please consider subclassing HFLM and overriding this and other methods as needed.

        :param pretrained: model identifier passed to `from_pretrained` / `from_quantized`
        :param revision: revision passed to every `from_pretrained` call made here
        :param dtype: converted with `get_dtype` and passed as `torch_dtype`
        :param trust_remote_code: forwarded to the model loaders
        :param parallelize, gpus, max_memory_per_gpu, max_cpu_memory, offload_folder:
            forwarded to `self._get_accelerate_args`, whose result is merged into
            the loader kwargs
        :param peft: if set, the loaded model is wrapped with `PeftModel.from_pretrained`
        :param delta: if set, a second model is loaded from this identifier and its
            weights are added in place to the base model's weights
        :param autogptq: falsy to load through `self.AUTO_MODEL_CLASS`; `True` or a
            weights filename to load through `AutoGPTQForCausalLM.from_quantized`
        :param kwargs: extra keyword arguments forwarded to the model loader
        :raises ValueError: if both `peft` and `delta` are given
        """

        # `kwargs` is always a dict; it is extended in place with the accelerate args.
        model_kwargs = kwargs

        model_kwargs.update(
            self._get_accelerate_args(
                parallelize=parallelize,
                device_map=model_kwargs.get("device_map", None),
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                gpus=gpus,
            )
        )

        if not autogptq:
            if model_kwargs.get("load_in_4bit", None):
                # Compare parsed versions: a plain string comparison orders
                # "4.4.0" after "4.30.0" and "4.100.0" before it.
                supports_4bit = version.parse(
                    transformers.__version__
                ) >= version.parse("4.30.0")
                assert supports_4bit, "load_in_4bit requires transformers >= 4.30.0"
                if supports_4bit and model_kwargs.get("bnb_4bit_compute_dtype", None):
                    model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
                        model_kwargs["bnb_4bit_compute_dtype"]
                    )

            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )
        else:
            try:
                from auto_gptq import AutoGPTQForCausalLM
            except ModuleNotFoundError as exception:
                # single message string (two positional args would render as a tuple)
                raise Exception(
                    "Tried to load auto_gptq, but auto-gptq is not installed, "
                    "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]"
                ) from exception

            self._model = AutoGPTQForCausalLM.from_quantized(
                pretrained,
                trust_remote_code=trust_remote_code,
                # `autogptq=True` lets the loader pick the weights file; a string names it
                model_basename=None if autogptq is True else Path(autogptq).stem,
                use_safetensors=True
                if autogptq is True
                else autogptq.endswith(".safetensors"),
                **model_kwargs,
            )

        if peft and delta:
            raise ValueError(
                "Cannot use both 'peft' and 'delta' options at the same time."
            )

        if peft:
            if model_kwargs.get("load_in_4bit", None):
                if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
                    raise AssertionError("load_in_4bit requires peft >= 0.4.0")
            if self._model.config.vocab_size != len(self.tokenizer):
                # resize model for LoRAs with added tokens
                eval_logger.info(
                    f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
                )
                self._model.resize_token_embeddings(len(self.tokenizer))
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )
        elif delta:
            if autogptq:
                eval_logger.warning(
                    "Delta weights might trigger unexpected behavior when used with AutoGPTQ."
                )
            _model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
                delta,
                revision=revision,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )
            # build the delta state dict once instead of once per parameter
            delta_state = _model_delta.state_dict()
            for name, param in self._model.state_dict().items():
                try:
                    param.data += delta_state[name]
                except KeyError as err:
                    raise KeyError(
                        f"Delta model is missing weights for layer: {name}"
                    ) from err
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to add delta weights to layer {name}. Error: {e}"
                    ) from e

            del delta_state
            del _model_delta

        return None

    def _create_tokenizer(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ],
        revision: Optional[str] = "main",
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
    ) -> None:
        """
        Set `self.tokenizer` during initialization.

        A truthy non-string `tokenizer` is used as-is (it must be a
        `PreTrainedTokenizer` or `PreTrainedTokenizerFast`). Otherwise a tokenizer is
        loaded with `AutoTokenizer.from_pretrained` from, in order of preference:
        the `tokenizer` string, the `pretrained` string, or `self.model.name_or_path`.
        """
        tokenizer_given = bool(tokenizer)

        if tokenizer_given and not isinstance(tokenizer, str):
            # pre-initialized tokenizer object passed by the caller
            assert isinstance(
                tokenizer,
                (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast),
            )
            self.tokenizer = tokenizer
            return None

        if tokenizer_given:
            tokenizer_source = tokenizer
        elif isinstance(pretrained, str):
            # Get tokenizer based on 'pretrained'
            tokenizer_source = pretrained
        else:
            # get the HF hub name via accessor on model
            tokenizer_source = self.model.name_or_path

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            tokenizer_source,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast=use_fast_tokenizer,
        )
        return None

Ethan Smith's avatar
Ethan Smith committed
773
    def _detect_batch_size(self, requests=None, pos: int = 0):
        """
        Probe for a batch size that survives repeated forward passes.

        When `requests` is given, the probe sequence lengths are taken from
        `requests[pos]` (left-truncated to the model window); otherwise
        `self.max_length` is used. The search starts at `self.max_batch_size`.
        With more than one process, the minimum over all ranks is returned.
        """
        if requests:
            _, context_enc, continuation_enc = requests[pos]
            window = self.max_length + 1
            max_length = len((context_enc + continuation_enc)[-window:][:-1])
            max_context_enc = len(context_enc[-window:])
            max_cont_enc = len(continuation_enc[-window:])
        else:
            max_length = self.max_length
            max_context_enc = max_length
            max_cont_enc = max_length

        # if OOM, then halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                length = max(max_context_enc, max_cont_enc)
                batched_conts = torch.ones(
                    (batch_size, length), device=self.device
                ).long()
                test_batch = torch.ones((batch_size, length), device=self.device).long()
                call_kwargs = {"attn_mask": test_batch, "labels": batched_conts}
            else:
                call_kwargs = {}
                test_batch = torch.ones(
                    (batch_size, max_length), device=self.device
                ).long()

            # the result stays bound to a local between iterations, as before
            for _ in range(5):
                out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)  # noqa: F841

            return batch_size

        try:
            batch_size = forward_batch()
        except RuntimeError as err:
            if "No executable batch size found" not in str(err):
                raise
            batch_size = 1

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
            local_bs = torch.tensor([batch_size], device=self.device)
            per_rank_bs = (
                self.accelerator.gather(local_bs).cpu().detach().numpy().tolist()
            )
            batch_size = min(per_rank_bs)

        clear_torch_cache()
        return batch_size

baberabb's avatar
baberabb committed
830
831
832
    def tok_encode(
        self, string: str, left_truncate_len=None, add_special_tokens=None
    ) -> List[int]:
        """
        Encode `string` with `self.tokenizer.encode`.

        `add_special_tokens` is forwarded when given. When it is None, causal LMs
        use `self.add_bos_token` and every other model class falls back to the
        tokenizer's own default. If `left_truncate_len` is truthy, only the last
        `left_truncate_len` tokens are kept.
        """
        if add_special_tokens is not None:
            # the caller explicitly defines the value
            encode_kwargs = {"add_special_tokens": add_special_tokens}
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            # by default for CausalLM - false or self.add_bos_token is set
            encode_kwargs = {"add_special_tokens": self.add_bos_token}
        else:
            # empty dict: use the predefined tokenizer param
            encode_kwargs = {}

        token_ids = self.tokenizer.encode(string, **encode_kwargs)

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            token_ids = token_ids[-left_truncate_len:]

        return token_ids

haileyschoelkopf's avatar
haileyschoelkopf committed
856
    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode a batch of strings. Converts to tensors and pads automatically,
        unlike `tok_encode`.

        :param strings: texts to tokenize together
        :param padding_side: value temporarily assigned to `self.tokenizer.padding_side`
        :param left_truncate_len: if truthy, keep only the last `left_truncate_len`
            columns of the input ids and attention mask
        :param truncation: forwarded to the tokenizer call
        :return: `(input_ids, attention_mask)` from the tokenizer output
        """
        add_special_tokens = {}
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            add_special_tokens = {"add_special_tokens": self.add_bos_token}

        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            encoding = self.tokenizer(
                strings,
                truncation=truncation,
                padding="longest",
                return_tensors="pt",
                **add_special_tokens,
            )
        finally:
            # restore the tokenizer's state even if tokenization raises
            self.tokenizer.padding_side = old_padding_side

        if left_truncate_len:
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]

        return encoding["input_ids"], encoding["attention_mask"]

Lintang Sutawika's avatar
Lintang Sutawika committed
887
888
    def tok_decode(self, tokens, skip_special_tokens=True):
        """Decode `tokens` with `self.tokenizer.decode`, forwarding `skip_special_tokens`."""
        decoded = self.tokenizer.decode(
            tokens, skip_special_tokens=skip_special_tokens
        )
        return decoded
889
890
891

    def _model_call(self, inps, attn_mask=None, labels=None):
        """
        Run the model without gradient tracking and return its logits.

        :param inps: torch.Tensor
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. the size of sequence may vary from call to call
        :param attn_mask: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :param labels: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :return
            A torch tensor of shape [batch, sequence, vocab] with the
            logits returned from the model's decoder
        """
        encoder_decoder_call = not (attn_mask is None and labels is None)

        with torch.no_grad():
            if not encoder_decoder_call:
                # decoder-only path: the input ids are the only argument
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                return self.model(inps).logits

            # seq2seq path: both extra tensors are required together
            assert attn_mask is not None and labels is not None
            assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
            outputs = self.model(
                input_ids=inps, attention_mask=attn_mask, labels=labels
            )
            return outputs.logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """
        Call `self.model.generate` on `context` with stop-sequence criteria.

        `temperature` defaults to 0.0. A temperature of 0.0 with no explicit
        `do_sample` selects greedy decoding (`do_sample=False`), and in that case
        the temperature entry is removed before calling `generate`.
        """
        # temperature = 0.0 if not set
        generation_kwargs.setdefault("temperature", 0.0)
        zero_temperature = generation_kwargs.get("temperature") == 0.0
        do_sample = generation_kwargs.get("do_sample", None)

        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
        if zero_temperature and do_sample is None:
            do_sample = False
            generation_kwargs["do_sample"] = do_sample

        # if do_sample is false and temp==0.0:
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF
        if zero_temperature and do_sample is False:
            generation_kwargs.pop("temperature")

        # build stopping criteria
        batch_dim, seq_dim = context.shape[0], context.shape[1]
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, seq_dim, batch_dim
        )
        generated = self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            **generation_kwargs,
        )
        return generated
942

Baber Abbasi's avatar
Baber Abbasi committed
943
944
945
    def _select_cont_toks(
        self, logits: torch.Tensor, contlen: int = None, inplen: int = None
    ) -> torch.Tensor:
        """
        Slice `logits` down to the positions that score the continuation.

        Causal LMs need both `contlen` and `inplen`; seq2seq models need only
        `contlen`. For any other model class the logits are returned unchanged.
        """
        model_class = self.AUTO_MODEL_CLASS

        if model_class == transformers.AutoModelForCausalLM:
            assert (
                contlen and inplen
            ), "Must pass input len and cont. len to select scored logits for causal LM"
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            first_cont_pos = inplen - contlen
            return logits[first_cont_pos:inplen]

        if model_class == transformers.AutoModelForSeq2SeqLM:
            assert (
                contlen and not inplen
            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
            # only discard right-padding.
            # the logits input to this fn only contain decoder-side tokens.
            return logits[:contlen]

        return logits

963
964
965
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """
        Compute one summed log-likelihood per request string.

        Each string is tokenized, split into rolling windows, scored with
        `self._loglikelihood_tokens`, and the per-window log-likelihoods are summed.
        With `batch_size == "auto"` the batch size is detected once up front.
        """
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        results = []
        all_args = [req.args for req in requests]
        for (string,) in tqdm(all_args, disable=(disable_tqdm or (self.rank != 0))):
            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [
                (None,) + utils.make_disjoint_window(window)
                for window in utils.get_rolling_token_windows(
                    token_list=self.tok_encode(string),
                    prefix_token=self.prefix_token_id,
                    max_seq_len=self.max_length,
                    context_len=1,
                )
            ]

            pad_amnt = 0
            if self.world_size > 1:
                # We pad out the external document-level iterator so the inner iterator doesn't hang
                window_count = torch.tensor(len(windows), device=self.device)
                counts = (
                    self.accelerator.gather(window_count).cpu().detach().numpy().tolist()
                )

                pad_amnt = max(counts) - counts[self.rank]
                if pad_amnt > 0:
                    windows += pad_amnt * [windows[0]]

            scored = self._loglikelihood_tokens(
                requests=windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )

            # pad_amnt is only ever non-zero in the multi-process branch above
            if pad_amnt > 0:
                scored = scored[:-pad_amnt]

            # discard is_greedy, keep the log-likelihood
            results.append(sum([entry[0] for entry in scored]))

        return results
Zhiwei Zhuang's avatar
Zhiwei Zhuang committed
1022

1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
    def _batch_scheduler(self, pos, n_reordered_requests):
        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
        if sched in self.batch_sizes:
            return self.batch_sizes[sched]
        if (len(self.batch_sizes) > 1) and (
            self.batch_sizes[sched - 1] == self.max_batch_size
        ):
            # if previous batch size is already maximal, skip recomputation
            self.batch_sizes[sched] = self.max_batch_size
            return self.batch_sizes[sched]
        print(
            f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
        )
Zhiwei Zhuang's avatar
Zhiwei Zhuang committed
1036
        self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
1037
1038
        print(f"Determined largest batch size: {self.batch_sizes[sched]}")
        return self.batch_sizes[sched]
1039

Ethan Smith's avatar
Ethan Smith committed
1040
    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: int = None,
    ) -> List[Tuple[float, bool]]:
        """
        Score tokenized (context, continuation) pairs in batches.

        :param requests:
            Tuples of (request key, context token ids, continuation token ids).
            The request key is only forwarded to `re_ord.get_cache` and to
            `self.cache_hook.add_partial`.
        :param disable_tqdm:
            Disable the progress bar (it is also disabled when `self.rank != 0`).
        :param override_bs:
            Batch size to use when `self.batch_size == "auto"`. When it is unset
            (or 0), batch sizes are supplied by `self._batch_scheduler` instead.
        :return:
            One `(summed continuation log-prob, argmax-matches-continuation)` tuple
            per scored request, passed through `re_ord.get_original`
            (presumably restoring the caller's request order -- confirm against Collator).
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)"
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
            # speeds up some multiple-choice tasks proportionally to the number of choices.
            # groups requests by context+continuation[:-1] and infer on one request/group.
            return req[-2] + req[-1][:-1]

        # grouping by context is only enabled for causal LMs with `self.logits_cache` set
        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by="contexts"
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
            and self.logits_cache
            else None,
            group_fn=_lookup_one_token_cont,
        )

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        n_reordered_requests = len(re_ord)
        # fixed batch size if configured; otherwise the override, or 0 when the
        # scheduler below decides the size per chunk
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto"
            and n_reordered_requests > 0
            and not override_bs
            else None
        )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inps = []
            cont_toks_list = []
            inplens = []

            # only populated on the seq2seq path
            conts = []
            encoder_attns = []

            # running maxima over the chunk, used as padding targets
            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    # encoder input is the context only; the continuation becomes the labels
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

                    cont = torch.tensor(
                        (continuation_enc)[-self.max_length :],
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
                        dtype=torch.long,
                        device=self.device,
                    )
                    (contlen,) = cont.shape

                    conts.append(cont)

                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )

                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                batched_inps = pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
                batched_inps = pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
                batched_conts = pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
                batched_encoder_mask = pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            multi_logits = F.log_softmax(
                self._model_call(batched_inps, **call_kwargs), dim=-1
            )  # [batch, padding_length (inp or cont), vocab]

            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
                # also discards + checks for "virtual tokens" in the causal LM's input window
                # from prompt/prefix tuning tokens, if applicable
                ctx_len = (
                    inplen + (logits.shape[0] - padding_len_inp)
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                    else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)

                # check for one-token continuation cache hits.
                # noop in case group_by != "contexts" or no cache hit and returns the
                # original args. Otherwise, expands the logits batch dimension and yields each
                # batch along with matching continuation tokens and prompt strings.
                # logits -> [1, seq, vocab]
                # NOTE: the loop variables below rebind request_str, cont_toks and logits
                for request_str, cont_toks, logits in re_ord.get_cache(
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
                    logits=logits,
                ):
                    cont_toks = torch.tensor(
                        cont_toks, dtype=torch.long, device=self.device
                    ).unsqueeze(0)  # [1, seq]
                    max_equal = (greedy_tokens == cont_toks).all()

                    # Obtain log-probs at the corresponding continuation token indices
                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                        -1
                    )  # [1, seq]

                    # Answer: (log prob, is-exact-match)
                    answer = (float(logits.sum()), bool(max_equal))

                    res.append(answer)

                    self.cache_hook.add_partial("loglikelihood", request_str, answer)
                    pbar.update(1)

        pbar.close()

        return re_ord.get_original(res)

1256
1257
1258
    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """
        Generate a continuation for every request until a stop sequence is hit.

        Requests are grouped by their generation kwargs (so that differently
        configured requests never share a batch), sorted longest-context first,
        generated batch by batch, and finally returned in the original request order.

        :param requests: List[Instance]
            Instances whose `args` are `(context, gen_kwargs)` pairs. `gen_kwargs`
            must be a dict; its `until` (str or list of str) and `max_gen_toks`
            entries are consumed here and every remaining entry is forwarded to
            `self._model_generate`.
        :param disable_tqdm: bool
            If True, the progress bar is disabled (it is also disabled on ranks != 0).
        :return: List[str]
            One generated string per request, truncated at the first occurrence
            of any non-empty stop sequence.
        :raises ValueError:
            If `gen_kwargs` is not a dict, if `until` is neither a str nor a list,
            or if `self.AUTO_MODEL_CLASS` is neither the causal nor the seq2seq
            auto-model class.
        """
        res = []

        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        # the scheduler is only used when "auto" was requested but no usable
        # batch size was detected above.
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )

            # unpack our keyword arguments.
            # deep-copied so that popping keys / extending `until` below never
            # mutates the request's own gen_kwargs (edge case for repeats > 1)
            kwargs = copy.deepcopy(gen_kwargs)
            until = None
            if "until" in kwargs:
                until = kwargs.pop("until")
                if isinstance(until, str):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(
                        f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                    )

            # add EOS token to stop sequences (only once, even if the caller
            # already listed it explicitly)
            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
            if not until:
                until = [eos]
            elif eos not in until:
                until.append(eos)

            if "max_gen_toks" in kwargs:
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # set the max length in tokens of inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length
            else:
                # fail with a clear message instead of an UnboundLocalError below
                raise ValueError(
                    "Cannot determine the maximum context length for "
                    f"AUTO_MODEL_CLASS={self.AUTO_MODEL_CLASS}"
                )

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        s = s.split(term)[0]

                res.append(s)

                # the cache key uses the original, unmodified gen_kwargs
                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res
1394

KonradSzafer's avatar
KonradSzafer committed
1395
1396
1397
1398
1399
1400
1401
1402
    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
        """
        Render a chat history (a list of message dicts) with the tokenizer's chat template.

        The template is applied with `tokenize=False` and
        `add_generation_prompt=True`, and the tokenizer's result is returned as-is.
        """
        template_options = {"tokenize": False, "add_generation_prompt": True}
        rendered_prompt = self.tokenizer.apply_chat_template(
            chat_history, **template_options
        )
        return rendered_prompt

1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
    def get_model_info(self) -> dict:
        """
        Method to get Hugging Face model information for experiment reproducibility.

        :return: dict
            Always contains `model_num_parameters`, `model_dtype`,
            `model_revision` and `model_sha`. `peft_sha` / `delta_sha` are added
            when `self.peft` / `self.delta` are truthy.
        """

        def get_model_num_params(model) -> int:
            """Return the parameter count of `model`, or -1 if it cannot be determined."""
            if hasattr(model, "num_parameters"):
                return model.num_parameters()
            if hasattr(model, "parameters"):
                return sum(p.numel() for p in model.parameters())
            return -1

        def get_model_dtype(model) -> str:
            """Return the `dtype` attribute of `model`, or "" if it has none."""
            # NOTE(review): the attribute is returned unchanged, so this may be a
            # dtype object rather than a `str` despite the annotation - confirm
            # what consumers of `model_dtype` expect before converting.
            if hasattr(model, "dtype"):
                return model.dtype
            return ""

        def get_model_sha(pretrained: str, revision: str) -> str:
            """Look up the commit SHA of a Hub repo at `revision`; return "" on any failure."""
            try:
                model_info = HfApi().model_info(repo_id=pretrained, revision=revision)
                return model_info.sha
            except Exception as e:
                # Best effort only: a failed lookup must not abort the evaluation
                # run, so log it and fall back to an empty SHA.
                eval_logger.warning(
                    f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
                )
                return ""

        model_info = {
            "model_num_parameters": get_model_num_params(self._model),
            "model_dtype": get_model_dtype(self._model),
            "model_revision": self.revision,
            "model_sha": get_model_sha(self.pretrained, self.revision),
        }
        # The adapter / delta repos are looked up with the same `self.revision`
        # as the base model.
        if self.peft:
            model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
        if self.delta:
            model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
        return model_info