huggingface.py 63.4 KB
Newer Older
Baber's avatar
types  
Baber committed
1
2
from __future__ import annotations

3
import copy
Lintang Sutawika's avatar
Lintang Sutawika committed
4
import logging
5
import os
Jeevan's avatar
Jeevan committed
6
from datetime import timedelta
7
from pathlib import Path
Baber's avatar
types  
Baber committed
8
from typing import TYPE_CHECKING, Literal
9

10
import jinja2
11
import torch
12
import torch.nn.functional as F
13
import transformers
Jeevan's avatar
Jeevan committed
14
15
16
17
18
from accelerate import (
    Accelerator,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
Nathan Habib's avatar
Nathan Habib committed
19
from accelerate.utils import get_max_memory
20
from huggingface_hub import HfApi
21
22
from packaging import version
from tqdm import tqdm
23
24
25
26
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
27
28

from lm_eval import utils
baberabb's avatar
baberabb committed
29
from lm_eval.api.instance import Instance
30
from lm_eval.api.model import TemplateLM
31
from lm_eval.api.registry import register_model
32
33
34
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
35
    configure_pad_token,
36
    get_dtype,
37
    handle_stop_sequences,
38
    pad_and_concat,
39
    postprocess_generated_text,
40
41
    stop_sequences_criteria,
)
42

43

44
if TYPE_CHECKING:
Baber's avatar
types  
Baber committed
45
    from transformers.quantizers.auto import AutoQuantizationConfig
46

Lintang Sutawika's avatar
Lintang Sutawika committed
47
# Module-level logger used for all of this module's log messages.
eval_logger = logging.getLogger(__name__)
48

lintangsutawika's avatar
lintangsutawika committed
49

50
@register_model("hf-auto", "hf", "huggingface")
51
class HFLM(TemplateLM):
52
53
54
55
56
57
58
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

59
    # transformers Auto* loader class. Left as None here; `_get_backend` fills it in
    # (AutoModelForCausalLM or AutoModelForSeq2SeqLM) only while it is still None,
    # so a subclass may preset it.
    AUTO_MODEL_CLASS = None
    # Fallback context length returned by `max_length` when no limit is found
    # on the model config or tokenizer.
    _DEFAULT_MAX_LENGTH = 2048
haileyschoelkopf's avatar
haileyschoelkopf committed
61

62
63
    def __init__(
        self,
        pretrained: str | transformers.PreTrainedModel,
        backend: Literal["default", "causal", "seq2seq"] = "default",
        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        revision: str | None = "main",
        subfolder: str = "",
        tokenizer: str
        | transformers.PreTrainedTokenizer
        | transformers.PreTrainedTokenizerFast
        | None = None,
        truncation: bool | None = False,
        logits_cache: bool = True,
        max_length: int | None = None,
        device: str | None = "cuda",
        dtype: str | torch.dtype | None = "auto",
        softmax_dtype: str | torch.dtype | None = None,
        mixed_precision_dtype: str | torch.dtype | None = None,
        batch_size: int | str | None = 1,
        max_batch_size: int | None = 64,
        trust_remote_code: bool | None = False,
        use_fast_tokenizer: bool | None = True,
        add_bos_token: bool | None = False,
        prefix_token_id: int | None = None,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: bool | None = False,
        max_memory_per_gpu: int | str | None = None,
        max_cpu_memory: int | str | None = None,
        offload_folder: str | os.PathLike | None = "./offload",
        # PEFT, delta weights and quantization options
        peft: str | None = None,
        delta: str | None = None,
        autogptq: bool | str | None = False,
        gptqmodel: bool | None = False,
        gguf_file: str | None = None,
        # end token for thinking, either the string or int token id.
        # splits to get response after this token (if provided).
        think_end_token: str | int | None = None,
        **kwargs,
    ) -> None:
        """
        Wrap a Hugging Face causal or seq2seq model for evaluation.

        Args:
            pretrained: Model id / local path (config, tokenizer and model are then
                loaded here), or an already-initialized `transformers.PreTrainedModel`.
                With a model object, device placement and distributed setup are
                skipped and `parallelize=True` is rejected.
            backend: "causal", "seq2seq", or "default" to infer it from the config.
            revision, subfolder, gguf_file, trust_remote_code: Forwarded to config,
                tokenizer and model loading.
            tokenizer, use_fast_tokenizer, add_bos_token: Forwarded to
                `_create_tokenizer`. `add_bos_token` is forced to True when the
                config's `model_type` contains "gemma".
            truncation, logits_cache, max_batch_size: Stored on the instance as-is.
            max_length: Stored as `_max_length` (context-length override).
            device: Requested device when running in a single process without
                `parallelize`; values outside the known device names fall back to
                "cuda" if available, else "cpu".
            dtype: Forwarded to `_create_model`.
            softmax_dtype, mixed_precision_dtype: Resolved with `get_dtype` when given.
            batch_size: An int, or "auto" / "auto:N" where N becomes `batch_schedule`.
            prefix_token_id: Overrides the token used as the loglikelihood prefix.
            parallelize, max_memory_per_gpu, max_cpu_memory, offload_folder: Options
                for splitting the model across devices; forwarded to `_create_model`.
            peft, delta, autogptq, gptqmodel: Adapter / delta-weight / GPTQ loading
                options, forwarded to `_create_model`.
            think_end_token: End-of-thinking marker; digit-only strings are
                converted to an int token id.
            **kwargs: Extra keyword arguments forwarded to `_create_model`.
        """
        super().__init__()

        # optionally: take in an already-initialized transformers.PreTrainedModel
        if not isinstance(pretrained, str):
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
            )
            assert not parallelize, (
                "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            )
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
            gpus = 0

        else:
            assert isinstance(device, str)
            assert isinstance(pretrained, str)
            assert isinstance(batch_size, (int, str))

            gpus = torch.cuda.device_count()
            # very long timeout so slow ranks are not killed by the process-group watchdog
            accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
            accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
            if accelerator.num_processes > 1:
                self.accelerator = accelerator

            if "npu" in accelerator.device.type:
                gpus = torch.npu.device_count()

            # using one process with no model parallelism
            if not (parallelize or accelerator.num_processes > 1):
                # use user-passed device
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(gpus)]
                    + ["mps", "mps:0"]
                    + [f"npu:{i}" for i in range(gpus)]
                )
                if device and device in device_list:
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and version.parse(
                        torch.__version__
                    ) < version.parse("2.1"):
                        raise RuntimeError(
                            f"mps requires torch >= 2.1. You have {torch.__version__}"
                        )
                else:
                    eval_logger.info("Device not specified")
                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                    self._device = (
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    )
            else:  # Parallelism managed by accelerate
                if device != "cuda":
                    eval_logger.info(
                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                    )
                # TODO: include in warning that `load_in_8bit` etc. affect this too
                self._device = (
                    self.accelerator.device
                    if hasattr(self, "accelerator")
                    else torch.device(device)
                )

            revision = str(revision)  # cast to string if not already one

            self._get_config(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
                gguf_file=gguf_file,
                subfolder=subfolder,
            )

        # determine which of 'causal' and 'seq2seq' backends to use for HF models
        self._get_backend(
            config=self.config, backend=backend, trust_remote_code=trust_remote_code
        )

        # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT
        self._create_tokenizer(
            pretrained,
            tokenizer,
            revision=revision,
            subfolder=subfolder,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
            gguf_file=gguf_file,
            add_bos_token=add_bos_token,
        )

        # a quantization config stored on the model config as a plain dict is
        # converted to a config object before being handed to model loading
        if (
            quantization_config := getattr(self.config, "quantization_config", None)
        ) is not None and isinstance(quantization_config, dict):
            from transformers.quantizers import AutoQuantizationConfig

            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

        # if we passed `pretrained` as a string, initialize our model now
        if isinstance(pretrained, str):
            self._create_model(
                pretrained=pretrained,
                revision=revision,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                parallelize=parallelize,
                gpus=gpus,
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                peft=peft,
                delta=delta,
                autogptq=autogptq,
                gptqmodel=gptqmodel,
                gguf_file=gguf_file,
                quantization_config=quantization_config,
                subfolder=subfolder,
                **kwargs,
            )

        # access self._model through self.model property outside this method
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model.tie_weights()

        self.think_end_token = (
            int(think_end_token)
            if (isinstance(think_end_token, str) and think_end_token.isdigit())
            else think_end_token
        )
        self.truncation = truncation
        self.logits_cache = logits_cache
        self.vocab_size = self.tokenizer.vocab_size
        # select (or create) a pad token to use
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)

        self.add_bos_token = add_bos_token
        if "gemma" in getattr(self.config, "model_type", ""):
            self.add_bos_token = True
            eval_logger.info(
                f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
            )

        self._max_length = max_length
        self.pretrained = pretrained
        self.delta = delta
        self.peft = peft
        self.revision = revision
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size
        self.softmax_dtype = (
            get_dtype(softmax_dtype) if softmax_dtype is not None else None
        )
        self.mixed_precision_dtype = (
            get_dtype(mixed_precision_dtype)
            if mixed_precision_dtype is not None
            else None
        )

        # "auto" or "auto:N": automatic batch sizing, optionally with schedule N
        if str(batch_size).startswith("auto"):
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            if (gpus >= 1 or str(self.device) == "mps") and not (
                parallelize or autogptq or hasattr(self, "accelerator")
            ):
                # TODO: can remove this whole snippet except in the mps case, perhaps?
                # place model onto device requested manually,
                # if not using HF Accelerate or device_map
                # or any other option that preloads model onto device
                try:
                    self.model.to(self.device)
                except ValueError:
                    eval_logger.debug(
                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                    )
            # multigpu data-parallel support when launched with accelerate
            # NOTE(review): with several accelerate processes but <= 1 visible GPU,
            # rank/world_size are not assigned here — confirm the base class defaults
            # are what is intended in that case.
            if gpus > 1:
                if accelerator.num_processes > 1:
                    if parallelize:
                        eval_logger.warning(
                            "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
                        )
                    elif gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                        if self.accelerator.is_local_main_process:
                            # report the process count actually used, matching the warning above
                            eval_logger.info(
                                f"Using {accelerator.num_processes} devices with data parallelism"
                            )

                    self._device = torch.device(f"{accelerator.device}")
                    self.accelerator = accelerator

                    # Work is sharded by (rank, world_size) and world_size counts
                    # processes on all machines, so rank must be the global process
                    # index; the local index repeats on every node in multi-node runs.
                    self._rank = self.accelerator.process_index
                    self._world_size = self.accelerator.num_processes
                else:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )
327

Nathan Habib's avatar
Nathan Habib committed
328
329
    def _get_accelerate_args(
        self,
        parallelize: bool | None = None,
        device_map: str | None = "auto",
        max_memory_per_gpu: int | str | None = None,
        max_cpu_memory: int | str | None = None,
        offload_folder: str | None = "./offload",
        gpus: int | None = None,
    ) -> dict:
        """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`.

        Args:
            parallelize: Whether to split the model across devices. None means
                "decide automatically": enabled when there are more GPUs than local
                processes. Forced to False when an accelerator is attached but
                `WORLD_SIZE` is not set.
            device_map: Device map to use; None means "auto" under model
                parallelism, or this process's single device otherwise.
            max_memory_per_gpu: Memory cap applied to each of `gpus` devices. When
                None, limits are estimated with `accelerate.utils.get_max_memory`.
            max_cpu_memory: Memory cap for CPU offload, added under the "cpu" key
                of `max_memory` when model parallelism is used.
            offload_folder: Offload directory, set only under model parallelism.
            gpus: Number of devices to cap when `max_memory_per_gpu` is given;
                defaults to `torch.cuda.device_count()` when None.

        Returns:
            A dict with "max_memory" and "device_map" (plus "offload_folder" under
            model parallelism), to be merged into `from_pretrained` kwargs.
        """
        num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
        if (
            num_machines == 0
            and hasattr(self, "accelerator")
            and self.accelerator is not None
        ):
            eval_logger.info(
                "We are not in a distributed setting for accelerate. Setting model_parallel to False."
            )
            parallelize = False

        if parallelize is None:
            # If parallelism is unset by the user, we automatically assign model parallelism
            # if enough extra GPUs are available
            max_memory_all_gpus = get_max_memory()
            # We just want gpu, not cpu, max memory
            max_memory_all_gpus.pop("cpu", None)
            parallelize = bool(num_local_processes < len(max_memory_all_gpus))
            eval_logger.info(
                f"Setting model parallel to {parallelize} since "
                f"the number of local processes is {num_local_processes} "
                f"and the number of GPUs is {len(max_memory_all_gpus)}"
            )

        args = {}
        if parallelize:  # Model parallelism will be used
            if max_memory_per_gpu is not None:  # Using the provided memory requirements
                # `range(None)` would raise; fall back to the visible CUDA device count
                num_gpus = gpus if gpus is not None else torch.cuda.device_count()
                max_memory_per_gpu_map = {
                    device_idx: max_memory_per_gpu for device_idx in range(num_gpus)
                }
            else:  # Estimating the possible memory requirements
                max_memory_all_gpus = get_max_memory()
                max_memory_all_gpus.pop("cpu", None)
                if not hasattr(self, "accelerator"):
                    max_memory_per_gpu_map = dict(max_memory_all_gpus)
                else:
                    # use only 1 / num_processes of the GPUs if we are running under accelerate launch
                    max_memory_per_gpu_map = {
                        k: v
                        for k, v in max_memory_all_gpus.items()
                        if k % num_local_processes
                        == (self.accelerator.process_index % num_local_processes)
                    }
            args["max_memory"] = max_memory_per_gpu_map
            args["device_map"] = "auto" if device_map is None else device_map
            eval_logger.info(
                f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}"
            )

            if max_cpu_memory is not None:
                # Add the CPU cap to the map that is actually returned; it used to be
                # written into an unused dict and never reached `from_pretrained`.
                max_memory_per_gpu_map["cpu"] = max_cpu_memory

            args["offload_folder"] = offload_folder
        elif (
            device_map is None
        ):  # No model parallelism, we use the default provided device for our model
            if hasattr(self, "accelerator"):
                device_map = {"": f"{self.accelerator.device}"}
            else:
                device_map = {"": str(self.device)}
            args["max_memory"] = None
            args["device_map"] = device_map
            eval_logger.info(
                f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
            )
        else:
            args["max_memory"] = None
            args["device_map"] = None
            eval_logger.info("Model parallel was set to False.")

        return args

416
417
418
419
420
    @property
    def config(self):
        """The transformers config object for this model, as stored in `self._config`."""
        model_config = self._config
        return model_config

421
422
423
424
425
426
427
428
    @property
    def model(self):
        """The wrapped model, unwrapped via `accelerator.unwrap_model` when an accelerator is attached."""
        if not hasattr(self, "accelerator"):
            return self._model
        return self.accelerator.unwrap_model(self._model)

429
430
431
432
433
    @property
    def eot_token_id(self):
        """End-of-text token id, taken from the tokenizer's EOS token id.

        "End of text" describes evaluation continuations better than "end of
        sentence", hence the name.
        """
        tok = self.tokenizer
        return tok.eos_token_id

434
435
436
437
438
439
440
441
442
    @property
    def prefix_token_id(self):
        """Token id used as the prefix for loglikelihood computations.

        Preference order: the user-supplied `prefix_token_id`, then the
        tokenizer's BOS id, then its EOS id.
        """
        override = self.custom_prefix_token_id
        if override is not None:
            return override
        bos_id = self.tokenizer.bos_token_id
        return bos_id if bos_id is not None else self.tokenizer.eos_token_id

443
444
    @property
    def max_length(self):
        """Maximum context length for the model.

        Resolution order:
        1. `self._max_length`, if truthy.
        2. The first of `n_positions`, `max_position_embeddings`, `n_ctx` present
           on the model config.
        3. The tokenizer's `model_max_length`, unless it holds the "no limit"
           sentinel.
        4. `_DEFAULT_MAX_LENGTH`.
        """
        if self._max_length:
            return self._max_length

        for config_attr in ("n_positions", "max_position_embeddings", "n_ctx"):
            if hasattr(self.model.config, config_attr):
                return getattr(self.model.config, config_attr)

        if not hasattr(self.tokenizer, "model_max_length"):
            return self._DEFAULT_MAX_LENGTH

        # == int(1e30): the placeholder transformers uses when no limit is configured
        no_limit_sentinel = 1000000000000000019884624838656
        tokenizer_limit = self.tokenizer.model_max_length
        if tokenizer_limit == no_limit_sentinel:
            return self._DEFAULT_MAX_LENGTH
        return tokenizer_limit
456

457
    @property
    def max_gen_toks(self) -> int:
        """Default generation-length budget: a constant 256 tokens."""
        default_budget = 256
        return default_budget

    @property
    def batch_size(self):
        """Per-device batch size: an int, or the string "auto" when automatic sizing was requested."""
        per_device = self.batch_size_per_gpu
        return per_device

    @property
    def device(self):
        """The `torch.device` selected during initialization (stored in `self._device`)."""
        selected = self._device
        return selected

    @property
    def rank(self):
        """Index of this process among `world_size` processes (0 when not distributed)."""
        process_rank = self._rank
        return process_rank

    @property
    def world_size(self):
        """Total number of evaluation processes (1 when not distributed)."""
        process_count = self._world_size
        return process_count

KonradSzafer's avatar
KonradSzafer committed
477
478
479
480
    @property
    def tokenizer_name(self) -> str:
        """Tokenizer `name_or_path` with every "/" replaced by "__" (no path separators remain)."""
        return "__".join(self.tokenizer.name_or_path.split("/"))

481
482
    def _get_backend(
        self,
        config: transformers.PretrainedConfig | transformers.AutoConfig,
        backend: Literal["default", "causal", "seq2seq"] = "default",
        trust_remote_code: bool | None = False,
    ) -> None:
        """
        Helper method during initialization.
        Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used.
        Sets `self.backend`, and sets `self.AUTO_MODEL_CLASS` appropriately if not already set.

        Args:
            config: Model config whose `model_type` decides the backend when
                `backend="default"`.
            backend: "causal" or "seq2seq" to force a backend, or "default" to infer it.
            trust_remote_code: Only used to silence the warning for model types
                unknown to transformers (expected for remote-code models).

        **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
        user must set `self.backend` to be either "causal" or "seq2seq" manually!**
        """
        assert backend in ["default", "causal", "seq2seq"]

        if backend != "default":
            # if we've settled on non-default backend, use that manually
            if backend in ["causal", "seq2seq"]:
                self.backend = backend
            eval_logger.info(
                f"Overrode HF model backend type, and using type '{self.backend}'"
            )
        else:
            # determine and use the default HF backend for this model, based on the
            # config that was passed in (not necessarily `self.config`).
            model_type = config.model_type
            if model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                # first check if model type is listed under seq2seq models, since some
                # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                # these special cases should be treated as seq2seq models.
                self.backend = "seq2seq"
                eval_logger.debug(f"Using model type '{self.backend}'")
            elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                self.backend = "causal"
                eval_logger.debug(f"Using model type '{self.backend}'")
            else:
                if not trust_remote_code:
                    # implicit concatenation with explicit spaces: the old backslash
                    # continuation embedded source indentation into the message
                    eval_logger.warning(
                        "HF model type is neither marked as CausalLM or Seq2SeqLM. "
                        "This is expected if your model requires `trust_remote_code=True` but may be an error otherwise. "
                        "Setting backend to causal"
                    )
                # if model type is neither in HF transformers causal or seq2seq model registries
                # then we default to assuming AutoModelForCausalLM
                self.backend = "causal"
                eval_logger.info(
                    f"Model type cannot be determined. Using default model type '{self.backend}'"
                )

        if self.AUTO_MODEL_CLASS is None:
            if self.backend == "causal":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            elif self.backend == "seq2seq":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
535
536
537
538
539
540

    def _get_config(
        self,
        pretrained: str,
        revision: str = "main",
        trust_remote_code: bool = False,
        gguf_file: str | None = None,
        subfolder: str = "",
    ) -> None:
        """Load the HuggingFace config for `pretrained` and store it on `self._config`.

        Nothing is returned; read the loaded config through the `config` property.
        """
        loader_options = {
            "revision": revision,
            "trust_remote_code": trust_remote_code,
            "gguf_file": gguf_file,
            "subfolder": subfolder,
        }
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained, **loader_options
        )

    def _create_model(
        self,
        pretrained: str,
Baber's avatar
types  
Baber committed
556
557
558
        revision: str | None = "main",
        dtype: str | torch.dtype | None = "auto",
        trust_remote_code: bool | None = False,
559
560
561
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # (accelerate naive PP (device_map) options)
Baber's avatar
types  
Baber committed
562
563
564
565
566
        parallelize: bool | None = False,
        gpus: int | None = None,
        max_memory_per_gpu: int | str | None = None,
        max_cpu_memory: int | str | None = None,
        offload_folder: str | None = "./offload",
567
        # PEFT, delta weights and quantization options
Baber's avatar
types  
Baber committed
568
569
570
571
572
573
        peft: str | None = None,
        delta: str | None = None,
        autogptq: bool | str | None = False,
        gptqmodel: bool | None = False,
        gguf_file: str | None = None,
        quantization_config: AutoQuantizationConfig | None = None,
574
        subfolder: str = "",
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
        **kwargs,
    ) -> None:
        """
        Initializes an HF or HF-compatible PreTrainedModel from scratch
        inside HFLM, using the kwargs passed into self.__init__().

        Also handles functionality such as AutoGPTQ usage and PEFT wrapping.

        For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
        (such as PyTorch models that are nearly, but not quite, fully mirroring
        HF's public interface relied on in this HFLM class)
        please consider subclassing HFLM and overriding this and other methods as needed.
        """

        model_kwargs = kwargs if kwargs else {}

Nathan Habib's avatar
Nathan Habib committed
591
592
593
        model_kwargs.update(
            self._get_accelerate_args(
                parallelize=parallelize,
Baber's avatar
types  
Baber committed
594
                device_map=kwargs.get("device_map"),
Nathan Habib's avatar
Nathan Habib committed
595
596
597
598
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                gpus=gpus,
599
            )
Nathan Habib's avatar
Nathan Habib committed
600
        )
601

602
        if not autogptq and not gptqmodel:
603
            if model_kwargs.get("load_in_4bit", None):
Baber Abbasi's avatar
Baber Abbasi committed
604
605
606
                assert transformers.__version__ >= "4.30.0", (
                    "load_in_4bit requires transformers >= 4.30.0"
                )
Baber's avatar
types  
Baber committed
607
608
609
610
611
            if transformers.__version__ >= "4.30.0" and (
                model_kwargs.get("load_in_4bit")
                and (compute_dtype := model_kwargs.get("bnb_4bit_compute_dtype"))
            ):
                model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(compute_dtype)
Nathan Habib's avatar
Nathan Habib committed
612

613
614
615
            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
616
                torch_dtype=get_dtype(dtype),
617
                trust_remote_code=trust_remote_code,
618
                gguf_file=gguf_file,
619
                quantization_config=quantization_config,
620
                subfolder=subfolder,
621
622
623
                **model_kwargs,
            )
        else:
624
625
626
            if autogptq and gptqmodel:
                raise ValueError(
                    "Cannot use both 'autogptq' and 'gptqmodel' options at the same time."
627
628
                )

629
630
631
632
633
634
635
            if autogptq:
                try:
                    from auto_gptq import AutoGPTQForCausalLM
                except ModuleNotFoundError as exception:
                    raise type(exception)(
                        "Tried to load auto_gptq, but auto-gptq is not installed ",
                        "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
Baber's avatar
types  
Baber committed
636
                    ) from exception
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654

                self._model = AutoGPTQForCausalLM.from_quantized(
                    pretrained,
                    trust_remote_code=trust_remote_code,
                    model_basename=None if autogptq is True else Path(autogptq).stem,
                    use_safetensors=True
                    if autogptq is True
                    else autogptq.endswith(".safetensors"),
                    **model_kwargs,
                )

            if gptqmodel:
                try:
                    from gptqmodel import GPTQModel
                except ModuleNotFoundError as exception:
                    raise type(exception)(
                        "Tried to load gptqmodel, but gptqmodel is not installed ",
                        "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
Baber's avatar
types  
Baber committed
655
                    ) from exception
656
657
658
659

                self._model = GPTQModel.from_quantized(
                    pretrained, trust_remote_code=trust_remote_code, **model_kwargs
                )
660

661
662
663
664
665
        if peft and delta:
            raise ValueError(
                "Cannot use both 'peft' and 'delta' options at the same time."
            )

666
        if peft:
Baber's avatar
types  
Baber committed
667
            from peft import PeftModel, __version__ as PEFT_VERSION
668

Baber's avatar
types  
Baber committed
669
670
671
672
            if model_kwargs.get("load_in_4bit") and version.parse(
                PEFT_VERSION
            ) < version.parse("0.4.0"):
                raise AssertionError("load_in_4bit requires peft >= 0.4.0")
673
674
            if self._model.config.vocab_size != len(self.tokenizer):
                # resize model for LoRAs with added tokens
675
676
677
                eval_logger.info(
                    f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
                )
678
                self._model.resize_token_embeddings(len(self.tokenizer))
679
680
681
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
        elif delta:
            if autogptq:
                eval_logger.warning(
                    "Delta weights might trigger unexpected behavior when used with AutoGPTQ."
                )
            _model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
                delta,
                revision=revision,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )
            for name, param in self._model.state_dict().items():
                try:
                    param.data += _model_delta.state_dict()[name]
                except KeyError:
Baber's avatar
types  
Baber committed
698
699
700
                    raise KeyError(
                        f"Delta model is missing weights for layer: {name}"
                    ) from None
701
702
703
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to add delta weights to layer {name}. Error: {e}"
Baber's avatar
types  
Baber committed
704
                    ) from e
705
706

            del _model_delta
707
708
709
710
711

        return None

    def _create_tokenizer(
        self,
        pretrained: str | transformers.PreTrainedModel,
        tokenizer: str
        | transformers.PreTrainedTokenizer
        | transformers.PreTrainedTokenizerFast
        | None,
        revision: str | None = "main",
        trust_remote_code: bool | None = False,
        use_fast_tokenizer: bool | None = True,
        gguf_file: str | None = None,
        add_bos_token: bool | None = False,
        subfolder: str | None = "",
    ) -> None:
        """Populate ``self.tokenizer`` during initialization.

        Resolution order:
        - ``tokenizer`` is falsy: load with ``transformers.AutoTokenizer`` from
          ``pretrained`` when it is a string, otherwise from
          ``self.model.name_or_path``.
        - ``tokenizer`` is a string: load it with ``transformers.AutoTokenizer``.
        - otherwise ``tokenizer`` must already be a ``PreTrainedTokenizer`` /
          ``PreTrainedTokenizerFast`` instance (asserted) and is used as-is.
        """
        load_kwargs = {"revision": revision, "trust_remote_code": trust_remote_code}

        # gguf format embeds its own tokenizer and is not compatible with the
        # hf tokenizer `use_fast` param, so only one of the two is passed.
        if gguf_file is not None and not tokenizer:
            load_kwargs["gguf_file"] = gguf_file
        else:
            load_kwargs["use_fast"] = use_fast_tokenizer

        if add_bos_token:
            load_kwargs["add_bos_token"] = True
        if subfolder:
            load_kwargs["subfolder"] = subfolder

        if not tokenizer:
            # No explicit tokenizer: fall back to the model's own identifier.
            source = (
                pretrained if isinstance(pretrained, str) else self.model.name_or_path
            )
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                source, **load_kwargs
            )
        elif isinstance(tokenizer, str):
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                tokenizer, **load_kwargs
            )
        else:
            assert isinstance(
                tokenizer,
                (
                    transformers.PreTrainedTokenizer,
                    transformers.PreTrainedTokenizerFast,
                ),
            )
            self.tokenizer = tokenizer
        return None

Ethan Smith's avatar
Ethan Smith committed
773
    def _detect_batch_size(self, requests=None, pos: int = 0):
        """Probe for the largest batch size whose forward pass fits in memory.

        The probe is wrapped in accelerate's ``find_executable_batch_size``,
        starting from ``self.max_batch_size``.

        :param requests: optional sequence of ``(key, context_enc,
            continuation_enc)`` triples. When given, probe tensor lengths are
            derived from ``requests[pos]`` truncated to the model window; when
            falsy, every probe length is ``self.max_length``.
        :param pos: index into ``requests`` of the sample used for sizing.
        :return: the detected batch size; 1 if no size fits, and the minimum
            across processes when ``self.world_size > 1``.
        """
        window = self.max_length + 1
        if requests:
            _, ctx_tokens, cont_tokens = requests[pos]
            max_length = len((ctx_tokens + cont_tokens)[-window:][:-1])
            max_context_enc = len(ctx_tokens[-window:])
            max_cont_enc = len(cont_tokens[-window:])
        else:
            max_length = max_context_enc = max_cont_enc = self.max_length

        # if OOM, then halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            if self.backend != "seq2seq":
                call_kwargs = {}
                test_batch = torch.ones(
                    (batch_size, max_length), device=self.device
                ).long()
            else:
                seq_len = max(max_context_enc, max_cont_enc)
                labels = torch.ones((batch_size, seq_len), device=self.device).long()
                test_batch = torch.ones(
                    (batch_size, seq_len), device=self.device
                ).long()
                call_kwargs = {"attn_mask": test_batch, "labels": labels}
            for _ in range(5):
                # `out` is rebound only after the next pass finishes, so the
                # previous result stays referenced while the next one is computed.
                out = F.log_softmax(  # noqa: F841
                    self._model_call(test_batch, **call_kwargs),
                    dim=-1,
                    dtype=self.softmax_dtype,
                )
            return batch_size

        try:
            batch_size = forward_batch()
        except RuntimeError as err:
            if "No executable batch size found" not in str(err):
                raise
            batch_size = 1

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
            local_bs = torch.tensor([batch_size], device=self.device)
            per_rank = self.accelerator.gather(local_bs).cpu().detach().numpy().tolist()
            batch_size = min(per_rank)

        clear_torch_cache()
        return batch_size

baberabb's avatar
baberabb committed
834
835
    def tok_encode(
        self, string: str, left_truncate_len=None, add_special_tokens=None
    ) -> list[int]:
        """Tokenize one string into a list of token ids with ``self.tokenizer.encode``.

        :param string: text to encode.
        :param left_truncate_len: if truthy, only the last ``left_truncate_len``
            token ids are returned.
        :param add_special_tokens: when not None, forwarded verbatim to the
            tokenizer. When None, causal backends pass ``self.add_bos_token``
            and every other backend passes nothing, so the tokenizer's own
            default applies.
        :return: list of token ids.
        """
        if add_special_tokens is not None:
            encode_kwargs = {"add_special_tokens": add_special_tokens}
        elif self.backend == "causal":
            encode_kwargs = {"add_special_tokens": self.add_bos_token}
        else:
            encode_kwargs = {}

        token_ids = self.tokenizer.encode(string, **encode_kwargs)

        # left-truncate so the most recent tokens are the ones kept
        return token_ids[-left_truncate_len:] if left_truncate_len else token_ids

haileyschoelkopf's avatar
haileyschoelkopf committed
860
    def tok_batch_encode(
        self,
        strings: list[str],
        padding_side: str = "left",
        left_truncate_len: int | None = None,
        truncation: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a batch of strings into padded ``(input_ids, attention_mask)``.

        Unlike ``tok_encode``, this pads every row to the longest string
        (``padding="longest"``) and returns PyTorch tensors.

        :param strings: texts to encode.
        :param padding_side: side to pad on for this call only. The tokenizer's
            previous ``padding_side`` is restored afterwards, even if
            tokenization raises.
        :param left_truncate_len: if truthy, keep only the last
            ``left_truncate_len`` columns of both tensors. A warning is logged
            when this actually drops tokens.
        :param truncation: forwarded to the tokenizer call.
        :return: ``(input_ids, attention_mask)`` tensors.
        """
        special_tokens_kwargs = {}
        if self.backend == "causal":
            special_tokens_kwargs["add_special_tokens"] = self.add_bos_token

        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            encoding = self.tokenizer(
                strings,
                truncation=truncation,
                padding="longest",
                return_tensors="pt",
                **special_tokens_kwargs,
            )
        finally:
            # Restore even on failure so later calls see the original setting.
            self.tokenizer.padding_side = old_padding_side

        if left_truncate_len:
            original_lengths = encoding["input_ids"].size(1)
            if original_lengths > left_truncate_len:
                eval_logger.warning(
                    f"Left truncation applied. Original sequence length was {original_lengths}, "
                    f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
                )
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]

        return encoding["input_ids"], encoding["attention_mask"]

Lintang Sutawika's avatar
Lintang Sutawika committed
897
898
    def tok_decode(self, tokens, skip_special_tokens=True):
        """Turn token ids back into text using ``self.tokenizer.decode``.

        :param tokens: token ids to decode.
        :param skip_special_tokens: forwarded unchanged to the tokenizer.
        """
        decoded_text = self.tokenizer.decode(
            tokens, skip_special_tokens=skip_special_tokens
        )
        return decoded_text
899
900
901

    def _model_call(self, inps, attn_mask=None, labels=None):
        """Run one gradient-free forward pass and return the logits.

        :param inps: torch.Tensor
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. the size of sequence may vary from call to call
        :param attn_mask: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :param labels: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :return
            A torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model's decoder

        The pass runs under ``torch.no_grad()`` and, when
        ``self.mixed_precision_dtype`` is set, under ``torch.autocast``.
        """
        autocast_ctx = torch.autocast(
            device_type=self.device.type,
            dtype=self.mixed_precision_dtype,
            enabled=self.mixed_precision_dtype is not None,
        )
        with torch.no_grad(), autocast_ctx:
            if attn_mask is None and labels is None:
                # decoder-only (or vision2seq) path: inputs only
                assert self.AUTO_MODEL_CLASS in (
                    transformers.AutoModelForCausalLM,
                    transformers.AutoModelForVision2Seq,
                )
                return self.model(inps).logits

            # encoder-decoder path: mask and labels must be supplied together
            assert attn_mask is not None and labels is not None
            assert transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS
            outputs = self.model(
                input_ids=inps, attention_mask=attn_mask, labels=labels
            )
            return outputs.logits
935
936

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Generate continuations for ``context`` with ``self.model.generate``.

        ``temperature`` defaults to 0.0. With a zero temperature and no explicit
        ``do_sample``, greedy decoding is selected (``do_sample=False``). Whenever
        decoding is greedy at temperature 0.0, the ``temperature`` key is removed,
        because ``do_sample=False`` already covers it and we don't want a warning
        from HF. Generation stops on any of the ``stop`` sequences.
        """
        temperature = generation_kwargs.setdefault("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample")

        if temperature == 0.0:
            # The temperature has to be a strictly positive float -- if it is
            # 0.0, use greedy decoding strategies
            if do_sample is None:
                do_sample = generation_kwargs["do_sample"] = False
            if do_sample is False:
                generation_kwargs.pop("temperature")

        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )
        precision_ctx = torch.autocast(
            device_type=self.device.type,
            dtype=self.mixed_precision_dtype,
            enabled=self.mixed_precision_dtype is not None,
        )
        with precision_ctx:
            return self.model.generate(
                input_ids=context,
                max_length=max_length,
                stopping_criteria=stopping_criteria,
                pad_token_id=self.tokenizer.pad_token_id,
                use_cache=True,
                **generation_kwargs,
            )
967

Baber Abbasi's avatar
Baber Abbasi committed
968
969
970
    def _select_cont_toks(
        self, logits: torch.Tensor, contlen: int = None, inplen: int = None
    ) -> torch.Tensor:
971
        if self.backend == "causal":
Baber Abbasi's avatar
Baber Abbasi committed
972
973
974
            assert contlen and inplen, (
                "Must pass input len and cont. len to select scored logits for causal LM"
            )
975
976
977
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            logits = logits[inplen - contlen : inplen]
978
        elif self.backend == "seq2seq":
Baber Abbasi's avatar
Baber Abbasi committed
979
980
981
            assert contlen and not inplen, (
                "Selecting scored logits for Seq2SeqLM requires only cont. len"
            )
haileyschoelkopf's avatar
haileyschoelkopf committed
982
            # only discard right-padding.
983
            # the logits input to this fn only contain decoder-side tokens.
haileyschoelkopf's avatar
haileyschoelkopf committed
984
985
            logits = logits[:contlen]

986
987
        return logits

988
    def loglikelihood_rolling(
        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[float]:
        """Score whole documents by summing log-likelihoods over rolling windows.

        Each request's string is tokenized with ``tok_encode``. It is split into
        windows by ``utils.get_rolling_token_windows`` (``max_seq_len`` is
        ``self.max_length``) and made disjoint with ``utils.make_disjoint_window``.
        All windows from all requests are scored together in batches by
        ``_loglikelihood_tokens``, then summed back per request.

        :param requests: instances whose ``args`` is a 1-tuple ``(string,)``.
        :param disable_tqdm: suppress progress bars. This applies to the
            window-collection bar and to every per-batch
            ``_loglikelihood_tokens`` call.
        :return: one summed log-likelihood per request, in request order.
        """
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        # First, collect all windows from all requests as (request_idx, window).
        all_windows = []
        request_window_counts = []  # number of windows contributed by each request
        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [(None,) + x for x in rolling_token_windows]

            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))

        # Distributed case: pad this rank up to the largest window count across
        # ranks by repeating its first window; those results are dropped below.
        # NOTE(review): a rank with zero windows but pad_amnt > 0 would raise
        # IndexError here -- confirm this cannot occur upstream.
        pad_amnt = 0
        if self.world_size > 1:
            mytensor = torch.tensor(len(all_windows), device=self.device)
            gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
            pad_amnt = max(gathered) - gathered[self.rank]
            if pad_amnt > 0:
                all_windows += pad_amnt * [all_windows[0]]

        all_nlls = []
        batch_size = adaptive_batch_size or self.batch_size
        for start in range(0, len(all_windows), batch_size):
            # Split the batch into request indices and the windows to score.
            batch_indices, batch_windows = zip(*all_windows[start : start + batch_size])
            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                # Honour the caller's choice instead of always showing a bar.
                disable_tqdm=disable_tqdm,
                override_bs=len(batch_windows),
            )
            all_nlls.extend(zip(batch_indices, batch_nlls))

        # Remove results produced for padding windows (only non-zero when
        # world_size > 1).
        if pad_amnt > 0:
            all_nlls = all_nlls[:-pad_amnt]

        # Reconstruct per-request loglikelihoods; windows were appended in request
        # order, so consecutive slices of `all_nlls` belong to consecutive requests.
        loglikelihoods = []
        offset = 0
        for request, window_count in zip(requests, request_window_counts):
            request_nlls = all_nlls[offset : offset + window_count]
            offset += window_count
            # each result is (loglikelihood, is_greedy); only the loglikelihood is summed
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (request.args[0],), request_total
            )

        return loglikelihoods
Zhiwei Zhuang's avatar
Zhiwei Zhuang committed
1073

1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
    def _batch_scheduler(self, pos, n_reordered_requests):
        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
        if sched in self.batch_sizes:
            return self.batch_sizes[sched]
        if (len(self.batch_sizes) > 1) and (
            self.batch_sizes[sched - 1] == self.max_batch_size
        ):
            # if previous batch size is already maximal, skip recomputation
            self.batch_sizes[sched] = self.max_batch_size
            return self.batch_sizes[sched]
        print(
            f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
        )
Zhiwei Zhuang's avatar
Zhiwei Zhuang committed
1087
        self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
1088
1089
        print(f"Determined largest batch size: {self.batch_sizes[sched]}")
        return self.batch_sizes[sched]
1090

Ethan Smith's avatar
Ethan Smith committed
1091
    def _loglikelihood_tokens(
baberabb's avatar
baberabb committed
1092
        self,
Baber's avatar
types  
Baber committed
1093
        requests: list[tuple[tuple[str, str], list[int], list[int]]],
baberabb's avatar
baberabb committed
1094
1095
        disable_tqdm: bool = False,
        override_bs: int = None,
Baber's avatar
types  
Baber committed
1096
    ) -> list[tuple[float, bool]]:
1097
1098
1099
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

Baber's avatar
types  
Baber committed
1100
        def _collate(req: tuple[tuple[str, str], list[int], list[int]]):
Baber Abbasi's avatar
Baber Abbasi committed
1101
            """Defines the key for the sorted method"""
1102
1103
1104
1105
1106
1107
1108
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

Baber Abbasi's avatar
Baber Abbasi committed
1109
            toks = req[1] + req[2]
1110
1111
            return -len(toks), tuple(toks)

Baber's avatar
types  
Baber committed
1112
        def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]):
Baber Abbasi's avatar
Baber Abbasi committed
1113
1114
            """Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)"
Baber Abbasi's avatar
Baber Abbasi committed
1115
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
Baber Abbasi's avatar
Baber Abbasi committed
1116
1117
1118
1119
1120
1121
1122
1123
            # speeds up some multiple-choice tasks proportionally to the number of choices.
            # groups requests by context+continuation[:-1] and infer on one request/group.
            return req[-2] + req[-1][:-1]

        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by="contexts"
1124
            if self.backend == "causal" and self.logits_cache
Baber Abbasi's avatar
Baber Abbasi committed
1125
1126
1127
            else None,
            group_fn=_lookup_one_token_cont,
        )
Benjamin Fattori's avatar
Benjamin Fattori committed
1128
1129
1130

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
Baber Abbasi's avatar
Baber Abbasi committed
1131
1132
1133
        n_reordered_requests = len(re_ord)
        batch_size = (
            self.batch_size
1134
1135
1136
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
Baber Abbasi's avatar
Baber Abbasi committed
1137
1138
1139
1140
            else 0
        )
        batch_fn = (
            self._batch_scheduler
1141
1142
1143
            if self.batch_size == "auto"
            and n_reordered_requests > 0
            and not override_bs
Baber Abbasi's avatar
Baber Abbasi committed
1144
            else None
1145
1146
        )

Baber Abbasi's avatar
Baber Abbasi committed
1147
        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
1148
1149
1150
1151
1152
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
haileyschoelkopf's avatar
haileyschoelkopf committed
1153
        for chunk in chunks:
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

haileyschoelkopf's avatar
haileyschoelkopf committed
1173
                # how this all works (illustrated on a causal decoder-only setup):
1174
1175
1176
1177
1178
1179
1180
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
1181
                if self.backend == "causal":
1182
1183
                    total_length = len(context_enc) + len(continuation_enc)
                    if total_length > self.max_length + 1:
1184
                        eval_logger.warning(
1185
1186
1187
1188
                            f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
                            f"exceeds model's maximum length ({self.max_length}). "
                            f"Truncating {total_length - self.max_length + 1} tokens from the left."
                        )
1189
1190
1191
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
1192
1193
                        device=self.device,
                    )
1194
                    (inplen,) = inp.shape
1195
                elif self.backend == "seq2seq":
1196
1197
1198
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
haileyschoelkopf's avatar
haileyschoelkopf committed
1199
                        device=self.device,
1200
                    )
1201
                    (inplen,) = inp.shape
1202
1203
1204
1205

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

1206
                    cont = torch.tensor(
haileyschoelkopf's avatar
haileyschoelkopf committed
1207
                        (continuation_enc)[-self.max_length :],
1208
1209
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
1210
                        dtype=torch.long,
1211
1212
                        device=self.device,
                    )
1213
1214
                    (contlen,) = cont.shape

1215
1216
                    conts.append(cont)

haileyschoelkopf's avatar
haileyschoelkopf committed
1217
1218
1219
1220
1221
                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )
1222

haileyschoelkopf's avatar
haileyschoelkopf committed
1223
1224
1225
1226
1227
                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )
1228
1229
1230
1231

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)
haileyschoelkopf's avatar
haileyschoelkopf committed
1232

1233
1234
            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
1235
            if self.backend == "causal":
1236
                batched_inps = pad_and_concat(
haileyschoelkopf's avatar
haileyschoelkopf committed
1237
1238
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
1239
            elif self.backend == "seq2seq":
1240
                # TODO: left-pad encoder inps and mask?
1241
                batched_inps = pad_and_concat(
haileyschoelkopf's avatar
haileyschoelkopf committed
1242
1243
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
1244
                batched_conts = pad_and_concat(
haileyschoelkopf's avatar
haileyschoelkopf committed
1245
1246
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
1247
                batched_encoder_mask = pad_and_concat(
haileyschoelkopf's avatar
haileyschoelkopf committed
1248
1249
1250
1251
1252
1253
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }
1254
1255

            multi_logits = F.log_softmax(
1256
1257
1258
                self._model_call(batched_inps, **call_kwargs),
                dim=-1,
                dtype=self.softmax_dtype,
1259
            )  # [batch, padding_length (inp or cont), vocab]
1260

Baber Abbasi's avatar
Baber Abbasi committed
1261
            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
1262
1263
1264
1265
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
haileyschoelkopf's avatar
haileyschoelkopf committed
1266
                # take only logits in the continuation
1267
                # (discard context toks if decoder-only ; discard right-padding)
1268
1269
                # also discards + checks for "virtual tokens" in the causal LM's input window
                # from prompt/prefix tuning tokens, if applicable
haileyschoelkopf's avatar
haileyschoelkopf committed
1270
                ctx_len = (
1271
                    inplen + (logits.shape[0] - padding_len_inp)
1272
                    if self.backend == "causal"
haileyschoelkopf's avatar
haileyschoelkopf committed
1273
1274
                    else None
                )
1275
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
haileyschoelkopf's avatar
haileyschoelkopf committed
1276
                logits = logits.unsqueeze(0)  # [1, seq, vocab]
1277
1278
1279
1280

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)

Baber Abbasi's avatar
Baber Abbasi committed
1281
1282
1283
1284
1285
                # check for one-token continuation cache hits.
                # noop in case group_by != "contexts" or no cache hit and returns the
                # original args. Otherwise, expands the logits batch dimension and yields each
                # batch along with matching continuation tokens and prompt strings.
                # logits -> [1, seq, vocab]
Baber's avatar
types  
Baber committed
1286
                for request_str, cont_toks, logits in re_ord.get_cache(  # noqa
Baber Abbasi's avatar
Baber Abbasi committed
1287
1288
1289
1290
1291
1292
1293
1294
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
                    logits=logits,
                ):
                    cont_toks = torch.tensor(
                        cont_toks, dtype=torch.long, device=self.device
                    ).unsqueeze(0)  # [1, seq]
1295
1296
1297
1298
1299
1300
                    # Use trailing slice [-cont_toks.shape[1]:] to handle variable length cont_len (but same ctx+cont[:-1]).
                    # i.e. continuations can be sliced at diff points. Collator ensures we have sufficient greedy_tokens
                    # by choosing key with longest cont if group_by="contexts".
                    max_equal = (
                        greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks
                    ).all()
Baber Abbasi's avatar
Baber Abbasi committed
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312

                    # Obtain log-probs at the corresponding continuation token indices
                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                        -1
                    )  # [1, seq]

                    # Answer: (log prob, is-exact-match)
                    answer = (float(logits.sum()), bool(max_equal))

                    res.append(answer)

1313
1314
1315
1316
1317
1318
1319
                    if request_str is not None:
                        # special case: loglikelihood_rolling produces a number of loglikelihood requests
                        # all with cache key None. instead do add_partial on the per-example level
                        # in the loglikelihood_rolling() function for those.
                        self.cache_hook.add_partial(
                            "loglikelihood", request_str, answer
                        )
Baber Abbasi's avatar
Baber Abbasi committed
1320
                    pbar.update(1)
haileyschoelkopf's avatar
haileyschoelkopf committed
1321
1322

        pbar.close()
haileyschoelkopf's avatar
haileyschoelkopf committed
1323

1324
1325
        return re_ord.get_original(res)

1326
    def generate_until(
        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[str]:
        """Generate a continuation for each request until a stop sequence is hit.

        Requests are grouped by their generation kwargs (so e.g. greedy and
        sampled decoding never share a batch), sorted longest-context-first,
        generated in batches, and finally restored to the original order.

        Args:
            requests: instances whose ``args`` are ``(context, gen_kwargs)``
                tuples; ``gen_kwargs`` must be a dict.
            disable_tqdm: suppress the progress bar (it is always suppressed
                on ranks other than 0).

        Returns:
            One generated string per request, in the original request order.

        Raises:
            ValueError: if a request's ``gen_kwargs`` is not a dict, or if
                ``max_gen_toks`` leaves no room for context in a causal model.
        """
        res = []

        def _collate(req: tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # Probe once for the largest batch size that fits, then reuse it.
            # Report through the module logger rather than stdout.
            eval_logger.info(
                "Passed argument batch_size = auto. Detecting largest batch size"
            )
            batch_size = self._detect_batch_size()
            eval_logger.info(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        # Only fall back to the dynamic batch scheduler when auto-detection
        # did not yield a usable (truthy) batch size.
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # All gen kwargs in a chunk are identical because the Collator
            # groups by them, so the first one is representative.
            gen_kwargs = all_gen_kwargs[0]
            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            # Copy before popping keys so the request's own dict is not
            # mutated (edge case for repeats > 1).
            kwargs = copy.deepcopy(gen_kwargs)
            # add EOS token to stop sequences
            until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

            # set the max length in tokens of inputs ("context_enc")
            if self.backend == "causal":
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
                # Explicit check rather than `assert`, so it still fires under `python -O`.
                if max_ctx_len <= 0:
                    raise ValueError(
                        f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
                    )
            elif self.backend == "seq2seq":
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            # Default total length = padded prompt length + generation budget,
            # unless the caller supplied max_length explicitly.
            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.backend == "causal":
                    cont_toks = cont_toks[context_enc.shape[1] :]

                # Handle integer think_end_token: find last occurrence and strip tokens after it
                if isinstance(self.think_end_token, int):
                    think_token_indices = [
                        i
                        for i, token in enumerate(cont_toks)
                        if token == self.think_end_token
                    ]
                    if think_token_indices:
                        cont_toks = cont_toks[think_token_indices[-1] + 1 :]

                s = self.tok_decode(cont_toks)

                # Strip leading whitespace if we removed thinking tokens
                if isinstance(self.think_end_token, int):
                    s = s.lstrip()

                # Apply post-processing: remove stop sequences and string-based thinking tokens
                s = postprocess_generated_text(
                    generation=s,
                    stop=until,
                    think_end_token=self.think_end_token
                    if isinstance(self.think_end_token, str)
                    else None,
                )
                res.append(s)

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res
1470

Baber Abbasi's avatar
Baber Abbasi committed
1471
    def apply_chat_template(
        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.

        Renders ``chat_history`` with the tokenizer's chat template as a string
        (``tokenize=False``). When ``add_generation_prompt`` is False, the final
        message is continued instead (``continue_final_message=True``).

        If rendering raises ``jinja2.exceptions.TemplateError`` and the history
        contains ``system`` messages, those messages are dropped and rendering
        is retried once. If there is no system message to drop, the original
        error is re-raised, because a retry would render the identical input.

        Returns:
            The rendered chat string.
        """
        template_kwargs = {
            "tokenize": False,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": not add_generation_prompt,
        }
        try:
            return self.tokenizer.apply_chat_template(chat_history, **template_kwargs)
        except jinja2.exceptions.TemplateError:
            without_system = [msg for msg in chat_history if msg["role"] != "system"]
            if len(without_system) == len(chat_history):
                # Nothing was removed, so retrying would fail the same way.
                raise
            eval_logger.warning(
                "Failed to apply chat template. removing the system role in chat history."
            )
            return self.tokenizer.apply_chat_template(
                without_system, **template_kwargs
            )
KonradSzafer's avatar
KonradSzafer committed
1497

1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
    def get_model_info(self) -> dict:
        """
        Method to get Hugging Face model information for experiment reproducibility.

        Returns:
            A dict with ``model_num_parameters`` (int, ``-1`` if it cannot be
            determined), ``model_dtype`` (str, ``""`` if the model has no
            ``dtype``), ``model_revision`` (``self.revision``) and
            ``model_sha`` (Hub commit SHA of ``self.pretrained``, ``""`` if the
            lookup fails). ``peft_sha`` / ``delta_sha`` are added when
            ``self.peft`` / ``self.delta`` are set.
        """

        def get_model_num_params(model) -> int:
            # Prefer the model's own helper; otherwise sum tensor element counts.
            if hasattr(model, "num_parameters"):
                return model.num_parameters()
            if hasattr(model, "parameters"):
                return sum(p.numel() for p in model.parameters())
            return -1

        def get_model_dtype(model) -> str:
            # Stringify so the value matches the `-> str` annotation and the
            # "" fallback, instead of leaking a raw dtype object.
            return str(model.dtype) if hasattr(model, "dtype") else ""

        def get_model_sha(pretrained: str, revision: str) -> str:
            # Best-effort: any exception from the Hub lookup is logged at
            # debug level and reported as "" instead of aborting.
            try:
                info = HfApi().model_info(repo_id=pretrained, revision=revision)
                return info.sha
            except Exception as e:
                eval_logger.debug(
                    f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
                )
                return ""

        model_info = {
            "model_num_parameters": get_model_num_params(self._model),
            "model_dtype": get_model_dtype(self._model),
            "model_revision": self.revision,
            "model_sha": get_model_sha(self.pretrained, self.revision),
        }
        # NOTE(review): the base model's revision is reused for the adapter /
        # delta repos; if those repos lack that revision, the SHA is recorded
        # as "". Confirm this is intended.
        if self.peft:
            model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
        if self.delta:
            model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
        return model_info