huggingface.py 54.5 KB
Newer Older
1
import copy
2
import os
Jeevan's avatar
Jeevan committed
3
from datetime import timedelta
4
5
6
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union

7
import torch
8
import torch.nn.functional as F
9
import transformers
Jeevan's avatar
Jeevan committed
10
11
12
13
14
15
from accelerate import (
    Accelerator,
    DistributedType,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
16
from huggingface_hub import HfApi
17
18
19
20
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
21
22
23
24
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
25
26

from lm_eval import utils
baberabb's avatar
baberabb committed
27
from lm_eval.api.instance import Instance
28
from lm_eval.api.model import TemplateLM
29
from lm_eval.api.registry import register_model
30
31
32
33
34
35
36
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
    get_dtype,
    pad_and_concat,
    stop_sequences_criteria,
)
37

38

39
eval_logger = utils.eval_logger


def _get_accelerate_args(
    device_map_option: Optional[str] = "auto",
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[str] = "./offload",
) -> dict:
    """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
    max_memory = {}
    if max_memory_per_gpu is not None:
        max_memory_per_gpu_map = {
            device_idx: max_memory_per_gpu
            for device_idx in range(torch.cuda.device_count())
        }
        max_memory.update(max_memory_per_gpu_map)
    if max_cpu_memory is not None:
        max_memory["cpu"] = max_cpu_memory

    args = {}
    if max_memory:
        args["max_memory"] = max_memory
    args["device_map"] = device_map_option
    args["offload_folder"] = offload_folder
    return args
65
66


67
@register_model("hf-auto", "hf", "huggingface")
68
class HFLM(TemplateLM):
69
70
71
72
73
74
75
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

76
    AUTO_MODEL_CLASS = None
77
    _DEFAULT_MAX_LENGTH = 2048
haileyschoelkopf's avatar
haileyschoelkopf committed
78

79
80
    def __init__(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ] = None,
        truncation: Optional[bool] = False,
        logits_cache: bool = True,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        max_batch_size: Optional[int] = 64,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
        # PEFT, delta weights and quantization options
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        **kwargs,
    ) -> None:
        """
        Load (or wrap) a HF model and tokenizer and set up device / data-parallel state.

        Args:
            pretrained: either a string passed to ``from_pretrained`` for the
                config, model and (by default) tokenizer, or an already-initialized
                ``transformers.PreTrainedModel``. In the latter case many other
                arguments are ignored and no distributed setup is performed.
            backend: "default" infers causal vs. seq2seq from the config;
                "causal" / "seq2seq" force the choice.
            revision: model revision; when ``subfolder`` is given (string
                ``pretrained`` only) it is appended as ``"<revision>/<subfolder>"``.
            subfolder: see ``revision``.
            tokenizer: tokenizer name or pre-initialized tokenizer; defaults to the
                tokenizer of ``pretrained``.
            truncation, logits_cache: stored on the instance for use by other methods.
            max_length: overrides the ``max_length`` property when truthy.
            device: target device when neither ``parallelize`` nor a multi-process
                ``accelerate launch`` is used; values outside the known device list
                fall back to cuda (if available) or cpu.
            dtype: forwarded (via ``get_dtype``) as ``torch_dtype`` when loading.
            batch_size: an int, or ``"auto"`` / ``"auto:<schedule>"`` to detect it
                at run time (``<schedule>`` is stored as ``self.batch_schedule``).
            max_batch_size: stored as ``self.max_batch_size``.
            trust_remote_code, use_fast_tokenizer: forwarded to the HF loaders.
            add_bos_token: stored as ``self.add_bos_token``; forced to True for Gemma.
            prefix_token_id: overrides the ``prefix_token_id`` property.
            parallelize, device_map_option, max_memory_per_gpu, max_cpu_memory,
            offload_folder: naive model splitting across GPUs via ``device_map``;
                the last four are only used when ``parallelize=True``.
            peft, delta, autogptq: PEFT adapter, delta weights and AutoGPTQ
                options, handled in ``_create_model``.
            **kwargs: forwarded to the model's ``from_pretrained`` call.

        Raises:
            AssertionError: ``parallelize=True`` combined with a pre-initialized model.
            RuntimeError: mps requested with torch < 2.1, or ``parallelize=True``
                combined with a multi-process ``accelerate launch`` on >1 GPU.
        """
        super().__init__()

        # optionally: take in an already-initialized transformers.PreTrainedModel
        if not isinstance(pretrained, str):
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
            )
            assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
            gpus = 0
            # The tokenizer is created exactly once, below, by `_create_tokenizer`,
            # which handles both a passed-in tokenizer and `model.name_or_path`.
        else:
            assert isinstance(device, str)
            assert isinstance(pretrained, str)
            assert isinstance(batch_size, (int, str))

            gpus = torch.cuda.device_count()
            # generous timeout so long-running collective ops don't abort the run
            accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
            accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
            if accelerator.num_processes > 1:
                self.accelerator = accelerator

            if not (parallelize or accelerator.num_processes > 1):
                # use user-passed device
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
                    + ["mps", "mps:0"]
                )
                if device and device in device_list:
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and version.parse(
                        torch.__version__
                    ) < version.parse("2.1"):
                        raise RuntimeError(
                            f"mps requires torch >= 2.1. You have {torch.__version__}"
                        )
                else:
                    eval_logger.info("Device not specified")
                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                    self._device = (
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    )
            else:
                if device != "cuda":
                    eval_logger.info(
                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                    )
                # TODO: include in warning that `load_in_8bit` etc. affect this too
                self._device = torch.device(device)

            # TODO: update this to be less of a hack once subfolder is fixed in HF
            revision = revision + ("/" + subfolder if subfolder is not None else "")

            self._get_config(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
            )

        # determine which of 'causal' and 'seq2seq' backends to use
        self._get_backend(
            config=self.config, backend=backend, trust_remote_code=trust_remote_code
        )

        # if we passed `pretrained` as a string, initialize our model now
        if isinstance(pretrained, str):
            self._create_model(
                pretrained=pretrained,
                revision=revision,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                parallelize=parallelize,
                device_map_option=device_map_option,
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                peft=peft,
                delta=delta,
                autogptq=autogptq,
                **kwargs,
            )

        # access self._model through self.model property outside this method
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model.tie_weights()

        # Compare the device *type* so that both "mps" and "mps:0" (both accepted
        # above) trigger manual placement; `str(device) == "mps"` missed "mps:0".
        if isinstance(pretrained, str) and (gpus >= 1 or self.device.type == "mps"):
            # TODO: can remove this whole snippet except in the mps case, perhaps?
            if not (parallelize or autogptq or hasattr(self, "accelerator")):
                # place model onto device requested manually,
                # if not using HF Accelerate or device_map
                # or any other option that preloads model onto device
                try:
                    self.model.to(self.device)
                except ValueError:
                    eval_logger.debug(
                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                    )

        self._create_tokenizer(
            pretrained,
            tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
        )

        self.truncation = truncation
        self.logits_cache = logits_cache
        self.vocab_size = self.tokenizer.vocab_size
        # select (or create) a pad token to use
        if self.tokenizer.pad_token:
            pass
        elif self.tokenizer.unk_token:
            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
        elif self.tokenizer.eos_token:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        else:
            if getattr(self.config, "model_type", None) == "qwen":
                # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
                self.tokenizer.pad_token = "<|endoftext|>"
            elif (
                self.tokenizer.__class__.__name__ == "RWKVWorldTokenizer"
                or self.tokenizer.__class__.__name__ == "Rwkv5Tokenizer"
            ):
                # The RWKV world tokenizer, does not allow for adding special tokens / setting the pad token (which is set as 0)
                # The additional tokenizer name check is needed, as there exists rwkv4 models with neox tokenizer
                # ---
                # Note that the world tokenizer class name, might change in the future for the final huggingface merge
                # https://github.com/huggingface/transformers/pull/26963
                assert self.tokenizer.pad_token_id == 0
            else:
                self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

        # BOS handling: user setting, except Gemma always gets a BOS token.
        self.add_bos_token = add_bos_token
        if getattr(self.config, "model_type", None) == "gemma":
            self.add_bos_token = True
            eval_logger.info(
                f"Model type is '{self.config.model_type}', a BOS token will be used as Gemma underperforms without it."
            )

        self._max_length = max_length

        self.pretrained = pretrained
        self.delta = delta
        self.peft = peft
        self.revision = revision
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size

        # "auto" or "auto:<schedule>" requests run-time batch size detection
        if str(batch_size).startswith("auto"):
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if parallelize:
                    if accelerator.num_processes > 1:
                        raise RuntimeError(
                            "Attempted to use both a HF Accelerate `device_map` and to launch via `accelerate launch`. If this is the case, please either remove `parallelize=True` from --model_args or launch outside of the Accelerate launcher."
                        )
                elif accelerator.num_processes == 1:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
                else:
                    if gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                    assert (
                        accelerator.distributed_type
                        in [
                            DistributedType.FSDP,
                            DistributedType.MULTI_GPU,
                        ]
                    ), "Unsupported distributed type provided. Only DDP and FSDP are supported."
                    if accelerator.distributed_type == DistributedType.FSDP:
                        self._model = accelerator.prepare(self.model)
                    else:
                        self._model = accelerator.prepare_model(
                            self.model, evaluation_mode=True
                        )
                    self._device = torch.device(
                        f"cuda:{accelerator.local_process_index}"
                    )
                    self.accelerator = accelerator

                    if self.accelerator.is_local_main_process:
                        eval_logger.info(f"Using {gpus} devices with data parallelism")

                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

361
362
363
364
365
366
367
368
    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

369
370
371
372
373
    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

374
375
376
377
378
379
380
381
382
    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

383
384
    @property
    def max_length(self):
385
386
387
388
389
390
391
392
393
394
395
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH
396

397
    @property
Ethan Smith's avatar
Ethan Smith committed
398
    def max_gen_toks(self) -> int:
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

417
418
    def _get_backend(
        self,
        config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
        trust_remote_code: Optional[bool] = False,
    ) -> None:
        """
        Helper method during initialization.

        Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder))
        model type to be used, and stores the matching transformers auto-class in
        ``self.AUTO_MODEL_CLASS``.

        Args:
            config: model config; its ``model_type`` is looked up in the HF
                seq2seq / causal model registries when ``backend == "default"``.
            backend: "default" to infer from ``config``; "causal" or "seq2seq"
                to force the choice.
            trust_remote_code: only used to suppress the warning emitted when the
                model type is found in neither registry.
        """
        assert backend in ["default", "causal", "seq2seq"]

        if backend != "default":
            # if we've settled on non-default backend, use that manually
            if backend == "causal":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            elif backend == "seq2seq":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
            eval_logger.info(
                f"Overrode HF model backend type, and using type '{backend}'"
            )
        else:
            # determine and use the default HF backend for this model, based on
            # its config + metadata. Both checks use the `config` argument.
            model_type = getattr(config, "model_type", None)
            if model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                # first check if model type is listed under seq2seq models, since some
                # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                # these special cases should be treated as seq2seq models.
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
            elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            else:
                if not trust_remote_code:
                    # implicit concatenation: a backslash continuation inside the
                    # literal would embed the next line's indentation in the message
                    eval_logger.warning(
                        "HF model type is neither marked as CausalLM or Seq2SeqLM. "
                        "This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
                    )
                # if model type is neither in HF transformers causal or seq2seq model registries
                # then we default to AutoModelForCausalLM
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

        assert self.AUTO_MODEL_CLASS in [
            transformers.AutoModelForCausalLM,
            transformers.AutoModelForSeq2SeqLM,
        ]
        return None

    def _get_config(
        self,
        pretrained: str,
        revision: str = "main",
        trust_remote_code: bool = False,
    ) -> None:
        """Load the HF config for ``pretrained`` and store it as ``self._config``.

        ``revision`` and ``trust_remote_code`` are forwarded unchanged to
        ``transformers.AutoConfig.from_pretrained``.
        """
        loader_options = dict(revision=revision, trust_remote_code=trust_remote_code)
        loaded_config = transformers.AutoConfig.from_pretrained(
            pretrained, **loader_options
        )
        self._config = loaded_config

    def _create_model(
        self,
        pretrained: str,
        revision: Optional[str] = "main",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        trust_remote_code: Optional[bool] = False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # (accelerate naive PP (device_map) options)
        parallelize: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        # PEFT, delta weights and quantization options
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        **kwargs,
    ) -> None:
        """
        Initializes an HF or HF-compatible PreTrainedModel from scratch
        inside HFLM, using the kwargs passed into self.__init__().

        Also handles functionality such as AutoGPTQ usage and PEFT wrapping.

        For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
        (such as PyTorch models that are nearly, but not quite, fully mirroring
        HF's public interface relied on in this HFLM class)
        please consider subclassing HFLM and overriding this and other methods as needed.

        The loaded model is stored in ``self._model``.

        Args:
            pretrained: identifier passed to ``from_pretrained`` / ``from_quantized``.
            revision: forwarded to the model (and PEFT / delta) loaders.
            dtype: converted with ``get_dtype`` and passed as ``torch_dtype``.
            trust_remote_code: forwarded to the loaders.
            parallelize, device_map_option, max_memory_per_gpu, max_cpu_memory,
            offload_folder: when ``parallelize`` is set, turned into accelerate
                kwargs via ``_get_accelerate_args``.
            peft: adapter identifier applied with ``PeftModel.from_pretrained``.
            delta: identifier of a model whose weights are added, parameter by
                parameter, onto the base model.
            autogptq: truthy to load through AutoGPTQ; ``True`` or a weights
                file name (its stem is the model basename, and a ``.safetensors``
                suffix selects safetensors).
            **kwargs: extra ``from_pretrained`` / ``from_quantized`` kwargs.

        Raises:
            ValueError: both ``peft`` and ``delta`` were given.
            AssertionError: ``load_in_4bit`` with transformers < 4.30.0, or with
                PEFT < 0.4.0 when ``peft`` is used.
            ModuleNotFoundError: ``autogptq`` requested but auto-gptq is not installed.
            KeyError: the delta model lacks a weight present in the base model.
            RuntimeError: adding a delta weight failed (e.g. shape mismatch).
        """
        # Validate option combinations before paying for a model load.
        if peft and delta:
            raise ValueError(
                "Cannot use both 'peft' and 'delta' options at the same time."
            )

        # copy so updates below never leak back into the caller's mapping
        model_kwargs = dict(kwargs)

        if parallelize:
            model_kwargs.update(
                _get_accelerate_args(
                    device_map_option,  # TODO: phase out device_map_option?
                    max_memory_per_gpu,
                    max_cpu_memory,
                    offload_folder,
                )
            )
        elif "device_map" not in model_kwargs:
            # set a device_map to initialize model on the right GPU.
            # this is needed because it seems that the default behavior
            # for quantized models now seems to be device_map="auto"
            # which breaks data-parallel mode.
            if hasattr(self, "accelerator"):
                model_kwargs.update(
                    {"device_map": {"": f"cuda:{self.accelerator.local_process_index}"}}
                )
            else:
                model_kwargs.update({"device_map": {"": str(self.device)}})

        if not autogptq:
            if model_kwargs.get("load_in_4bit", None):
                # Parse versions: a plain string comparison misorders releases
                # (e.g. "4.9.0" >= "4.30.0" is True lexicographically).
                assert version.parse(transformers.__version__) >= version.parse(
                    "4.30.0"
                ), "load_in_4bit requires transformers >= 4.30.0"
                if model_kwargs.get("bnb_4bit_compute_dtype", None):
                    model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
                        model_kwargs["bnb_4bit_compute_dtype"]
                    )
            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )
        else:
            try:
                from auto_gptq import AutoGPTQForCausalLM
            except ModuleNotFoundError as err:
                raise ModuleNotFoundError(
                    "Tried to load auto_gptq, but auto-gptq is not installed. "
                    "Please install auto-gptq via `pip install lm-eval[gptq]` "
                    "or `pip install -e .[gptq]`"
                ) from err

            self._model = AutoGPTQForCausalLM.from_quantized(
                pretrained,
                trust_remote_code=trust_remote_code,
                model_basename=None if autogptq is True else Path(autogptq).stem,
                use_safetensors=True
                if autogptq is True
                else autogptq.endswith(".safetensors"),
                **model_kwargs,
            )

        if peft:
            if model_kwargs.get("load_in_4bit", None):
                if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
                    raise AssertionError("load_in_4bit requires peft >= 0.4.0")
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )
        elif delta:
            if autogptq:
                eval_logger.warning(
                    "Delta weights might trigger unexpected behavior when used with AutoGPTQ."
                )
            _model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
                delta,
                revision=revision,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )
            # Build the delta state dict once instead of on every iteration.
            delta_state_dict = _model_delta.state_dict()
            for name, param in self._model.state_dict().items():
                if name not in delta_state_dict:
                    raise KeyError(f"Delta model is missing weights for layer: {name}")
                try:
                    param.data += delta_state_dict[name]
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to add delta weights to layer {name}. Error: {e}"
                    ) from e

            del delta_state_dict, _model_delta

        return None

    def _create_tokenizer(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ],
        revision: Optional[str] = "main",
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
    ) -> None:
        """
        Helper method during initialization.

        Create a tokenizer object corresponding to the correct
        tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
        The result is stored in ``self.tokenizer``.

        Source selection, in order:
          * a (truthy) non-string ``tokenizer`` is used as-is after a type check;
          * a string ``tokenizer`` is loaded with ``AutoTokenizer.from_pretrained``;
          * otherwise the tokenizer is loaded from ``pretrained`` when it is a
            string, else from ``self.model.name_or_path``.
        ``revision``, ``trust_remote_code`` and ``use_fast_tokenizer`` (as
        ``use_fast``) are forwarded to every ``from_pretrained`` call.
        """
        if tokenizer and not isinstance(tokenizer, str):
            # pre-initialized tokenizer object: validate and use it directly
            assert isinstance(
                tokenizer,
                (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast),
            )
            self.tokenizer = tokenizer
            return None

        if tokenizer:
            tokenizer_source = tokenizer
        elif isinstance(pretrained, str):
            tokenizer_source = pretrained
        else:
            # get the HF hub name via accessor on model
            tokenizer_source = self.model.name_or_path

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            tokenizer_source,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast=use_fast_tokenizer,
        )
        return None

    def _detect_batch_size(self, requests=None, pos: int = 0):
        """
        Find a batch size whose forward passes succeed for the current model.

        ``find_executable_batch_size`` starts at ``self.max_batch_size`` and
        retries ``forward_batch`` with smaller sizes when it fails (halving on
        OOM). If no size works, the batch size falls back to 1.

        Args:
            requests: optional sequence of ``(_, context_enc, continuation_enc)``
                triples. When given, the request at ``pos`` sets the probe
                lengths, truncated to the last ``self.max_length + 1`` tokens.
                Otherwise every probe length is ``self.max_length``.
            pos: index into ``requests`` of the request to size against.

        Returns:
            The chosen batch size; with ``world_size > 1``, the minimum across ranks.
        """
        if requests:
            _, context_enc, continuation_enc = requests[pos]
            max_length = len(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
            )
            max_context_enc = len(context_enc[-(self.max_length + 1) :])
            max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
        else:
            max_length = self.max_length
            max_context_enc = max_length
            max_cont_enc = max_length

        # if OOM, then halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                length = max(max_context_enc, max_cont_enc)
                batched_conts = torch.ones(
                    (batch_size, length), device=self.device
                ).long()
                test_batch = torch.ones((batch_size, length), device=self.device).long()
                call_kwargs = {
                    "attn_mask": test_batch,
                    "labels": batched_conts,
                }
            else:
                call_kwargs = {}
                test_batch = torch.ones(
                    (batch_size, max_length), device=self.device
                ).long()
            # several passes so a size that only fails intermittently is rejected
            for _ in range(5):
                out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)  # noqa: F841

            return batch_size

        try:
            batch_size = forward_batch()
        except RuntimeError as e:
            if "No executable batch size found" in str(e):
                batch_size = 1
            else:
                raise

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
            max_rnk_bs = torch.tensor([batch_size], device=self.device)
            gathered = (
                self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
            )
            batch_size = min(gathered)

        clear_torch_cache()
        return batch_size

    def tok_encode(
        self, string: str, left_truncate_len=None, add_special_tokens=None
    ) -> List[int]:
        """Encode one string into a list of token ids.

        :param string: text to encode.
        :param left_truncate_len: if truthy, only the last ``left_truncate_len``
            ids are kept.
        :param add_special_tokens: forwarded to ``tokenizer.encode`` when not
            ``None``. When ``None``, causal LMs pass ``self.add_bos_token`` and
            every other model class uses the tokenizer's own default.
        :return: the token ids.
        """
        if add_special_tokens is not None:
            # the caller decides explicitly
            special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            special_tokens_kwargs = {"add_special_tokens": self.add_bos_token}
        else:
            special_tokens_kwargs = {}

        encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
        # left-truncate so the most recent tokens are kept
        return encoding[-left_truncate_len:] if left_truncate_len else encoding

haileyschoelkopf's avatar
haileyschoelkopf committed
743
    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of strings into padded ``(input_ids, attention_mask)``.

        Unlike ``tok_encode``, this returns PyTorch tensors padded to the
        longest sequence in the batch.

        :param strings: texts to encode.
        :param padding_side: side the tokenizer pads on, for this call only.
            The tokenizer's previous ``padding_side`` is always restored
            afterwards, even if encoding raises.
        :param left_truncate_len: if truthy, only the last ``left_truncate_len``
            positions of both tensors are kept.
        :param truncation: forwarded to the tokenizer.
        :return: ``(input_ids, attention_mask)``.
        """
        add_special_tokens = {}
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            add_special_tokens = {"add_special_tokens": self.add_bos_token}

        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            encoding = self.tokenizer(
                strings,
                truncation=truncation,
                padding="longest",
                return_tensors="pt",
                **add_special_tokens,
            )
        finally:
            # the tokenizer is shared state: never leak the temporary override
            self.tokenizer.padding_side = old_padding_side

        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        if left_truncate_len:
            input_ids = input_ids[:, -left_truncate_len:]
            attention_mask = attention_mask[:, -left_truncate_len:]
        return input_ids, attention_mask

Lintang Sutawika's avatar
Lintang Sutawika committed
774
775
    def tok_decode(self, tokens, skip_special_tokens=True):
        """Decode token ids back to text with ``self.tokenizer.decode``.

        :param tokens: whatever the tokenizer's ``decode`` accepts.
        :param skip_special_tokens: forwarded to ``decode`` (default True).
        :return: the decoded string.
        """
        decode = self.tokenizer.decode
        return decode(tokens, skip_special_tokens=skip_special_tokens)
776
777
778

    def _model_call(self, inps, attn_mask=None, labels=None):
        """
haileyschoelkopf's avatar
haileyschoelkopf committed
779
        :param inps: torch.Tensor
780
781
782
783
784
785
786
787
788
789
790
791
792
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. the size of sequence may vary from call to call
        :param attn_mask: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :param labels: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :return
            A torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model's decoder
        """
        with torch.no_grad():
793
794
            if attn_mask is not None or labels is not None:
                assert attn_mask is not None and labels is not None
795
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
haileyschoelkopf's avatar
haileyschoelkopf committed
796
797
798
                return self.model(
                    input_ids=inps, attention_mask=attn_mask, labels=labels
                ).logits
799
800
801
802
803
            else:
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                return self.model(inps).logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
Baber Abbasi's avatar
Baber Abbasi committed
804
        # temperature = 0.0 if not set
805
806
807
        # if do_sample is false and temp==0.0:
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF
Baber Abbasi's avatar
Baber Abbasi committed
808
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
809
        do_sample = generation_kwargs.get("do_sample", None)
810
811
812
813
814

        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

Baber Abbasi's avatar
Baber Abbasi committed
815
816
        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")
817
818
        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
819
            self.tokenizer, stop, context.shape[1], context.shape[0]
820
        )
821
        return self.model.generate(
822
            input_ids=context,
823
824
            max_length=max_length,
            stopping_criteria=stopping_criteria,
825
            pad_token_id=self.tokenizer.pad_token_id,
826
827
828
            use_cache=True,
            **generation_kwargs,
        )
829

Baber Abbasi's avatar
Baber Abbasi committed
830
831
832
    def _select_cont_toks(
        self, logits: torch.Tensor, contlen: int = None, inplen: int = None
    ) -> torch.Tensor:
833
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
haileyschoelkopf's avatar
haileyschoelkopf committed
834
835
836
            assert (
                contlen and inplen
            ), "Must pass input len and cont. len to select scored logits for causal LM"
837
838
839
840
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            logits = logits[inplen - contlen : inplen]
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
haileyschoelkopf's avatar
haileyschoelkopf committed
841
842
843
844
            assert (
                contlen and not inplen
            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
            # only discard right-padding.
845
            # the logits input to this fn only contain decoder-side tokens.
haileyschoelkopf's avatar
haileyschoelkopf committed
846
847
            logits = logits[:contlen]

848
849
        return logits

850
851
852
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Score each request's whole string as a sum of rolling-window log-likelihoods.

        Each string is tokenized and split with ``utils.get_rolling_token_windows``
        (``max_seq_len=self.max_length``, ``context_len=1``). The windows are
        made disjoint and scored through ``_loglikelihood_tokens``.

        :param requests: instances whose ``args`` is a 1-tuple ``(string,)``.
        :param disable_tqdm: suppress the progress bar (always hidden on ranks != 0).
        :return: one summed log-likelihood per request, in input order.
        """
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        loglikelihoods = []
        all_args = [req.args for req in requests]
        for (string,) in tqdm(all_args, disable=(disable_tqdm or (self.rank != 0))):
            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [
                (None,) + utils.make_disjoint_window(window)
                for window in utils.get_rolling_token_windows(
                    token_list=self.tok_encode(string),
                    prefix_token=self.prefix_token_id,
                    max_seq_len=self.max_length,
                    context_len=1,
                )
            ]

            pad_amnt = 0
            if self.world_size > 1:
                # pad the document-level window list to the longest one across
                # ranks so the inner (collective) iteration doesn't hang
                local_count = torch.tensor(len(windows), device=self.device)
                counts = (
                    self.accelerator.gather(local_count).cpu().detach().numpy().tolist()
                )
                pad_amnt = max(counts) - counts[self.rank]
                if pad_amnt > 0:
                    windows += pad_amnt * [windows[0]]

            scored = self._loglikelihood_tokens(
                requests=windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )
            if self.world_size > 1 and pad_amnt > 0:
                # throw away the results for the padding windows
                scored = scored[:-pad_amnt]

            # each result is (logprob, is_greedy); only the logprob is summed
            loglikelihoods.append(sum(item[0] for item in scored))

        return loglikelihoods
Zhiwei Zhuang's avatar
Zhiwei Zhuang committed
909

910
911
912
913
914
915
916
917
918
919
920
921
922
    def _batch_scheduler(self, pos, n_reordered_requests):
        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
        if sched in self.batch_sizes:
            return self.batch_sizes[sched]
        if (len(self.batch_sizes) > 1) and (
            self.batch_sizes[sched - 1] == self.max_batch_size
        ):
            # if previous batch size is already maximal, skip recomputation
            self.batch_sizes[sched] = self.max_batch_size
            return self.batch_sizes[sched]
        print(
            f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
        )
Zhiwei Zhuang's avatar
Zhiwei Zhuang committed
923
        self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
924
925
        print(f"Determined largest batch size: {self.batch_sizes[sched]}")
        return self.batch_sizes[sched]
926

Ethan Smith's avatar
Ethan Smith committed
927
    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: int = None,
    ) -> List[Tuple[float, bool]]:
        """Score tokenized ``(key, context_enc, continuation_enc)`` requests.

        :param requests: triples whose first element is the key handed to
            ``cache_hook``, followed by context and continuation token ids.
        :param disable_tqdm: suppress the progress bar (always hidden on ranks != 0).
        :param override_bs: batch size to use when ``self.batch_size == "auto"``.
            When falsy under "auto", ``_batch_scheduler`` chooses per segment.
        :return: per request, ``(sum of continuation log-probs, whether the
            per-token argmax equals the continuation)``, in input order.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Sort key: longest context+continuation first, ties broken by tokens."""
            # sorting longest-first means time estimates are over- rather than
            # under-estimates, the first item of a batch fixes the padded length
            # (which keeps adaptive batching simple), and OOMs surface right away
            all_toks = req[1] + req[2]
            return -len(all_toks), tuple(all_toks)

        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Group key for reusing logits across one-token continuations.

            Used with group_by="contexts". Requests that share
            context + continuation[:-1] are inferred once per group, which speeds
            up multiple-choice tasks proportionally to the number of choices.
            """
            return req[-2] + req[-1][:-1]

        group_by_contexts = (
            self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
            and self.logits_cache
        )
        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by="contexts" if group_by_contexts else None,
            group_fn=_lookup_one_token_cont,
        )

        # automatic (variable) batch size detection for vectorization
        n_reordered_requests = len(re_ord)
        if self.batch_size != "auto":
            batch_size = self.batch_size
        elif override_bs is not None:
            batch_size = override_bs
        else:
            batch_size = 0
        use_scheduler = (
            self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs
        )
        batch_fn = self._batch_scheduler if use_scheduler else None

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
        is_causal = self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
        is_seq2seq = self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM

        for chunk in chunks:
            inps = []
            cont_toks_list = []
            inplens = []
            conts = []
            encoder_attns = []
            padding_len_inp = None
            padding_len_cont = None

            # each (context, continuation) pair becomes its own tensor; they are
            # padded and stacked into one batch, run through the model, and the
            # per-request logits are then sliced back out
            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if is_causal:
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif is_seq2seq:
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

                    # TODO: left-shift these?
                    # TODO: our code assumes we never end up truncating conts for either model type
                    cont = torch.tensor(
                        (continuation_enc)[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (contlen,) = cont.shape
                    conts.append(cont)
                    padding_len_cont = (
                        contlen
                        if padding_len_cont is None
                        else max(padding_len_cont, contlen)
                    )

                padding_len_inp = (
                    inplen if padding_len_inp is None else max(padding_len_inp, inplen)
                )
                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if is_causal:
                batched_inps = pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif is_seq2seq:
                # TODO: left-pad encoder inps and mask?
                batched_inps = pad_and_concat(padding_len_inp, inps)  # [batch, padding_len_inp]
                batched_conts = pad_and_concat(padding_len_cont, conts)  # [batch, padding_len_cont]
                batched_encoder_mask = pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            multi_logits = F.log_softmax(
                self._model_call(batched_inps, **call_kwargs), dim=-1
            )  # [batch, padding_length (inp or cont), vocab]

            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                contlen = len(cont_toks)
                # keep only the continuation logits: drop context toks (decoder-only)
                # and right-padding; for causal LMs the shift also accounts for
                # any "virtual tokens" from prompt/prefix tuning in the input window
                ctx_len = (
                    inplen + (logits.shape[0] - padding_len_inp) if is_causal else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # per-token argmax, compared to the continuation below
                greedy_tokens = logits.argmax(dim=-1)

                # one-token continuation cache: a no-op (yields the original
                # args once) unless group_by == "contexts" and there is a hit,
                # in which case it yields every matching request with expanded logits
                for cached_str, cached_cont, cached_logits in re_ord.get_cache(
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
                    logits=logits,
                ):
                    cont_ids = torch.tensor(
                        cached_cont, dtype=torch.long, device=self.device
                    ).unsqueeze(0)  # [1, seq]
                    max_equal = (greedy_tokens == cont_ids).all()

                    # log-prob at each continuation token index
                    token_logprobs = torch.gather(
                        cached_logits, 2, cont_ids.unsqueeze(-1)
                    ).squeeze(-1)  # [1, seq]

                    # Answer: (log prob, is-exact-match)
                    answer = (float(token_logprobs.sum()), bool(max_equal))
                    res.append(answer)

                    self.cache_hook.add_partial("loglikelihood", cached_str, answer)
                    pbar.update(1)

        pbar.close()
        return re_ord.get_original(res)

1143
1144
1145
    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """Generate text for each request until a stop sequence or token budget.

        Requests are grouped by their generation kwargs, sorted by tokenized
        context length (longest first), and generated in batches.

        :param requests: instances whose ``args`` are ``(context, gen_kwargs)``.
        :param disable_tqdm: suppress the progress bar (always hidden on ranks != 0).
        :return: generated strings, in input order.
        :raises ValueError: if ``gen_kwargs`` is not a dict, or if its ``until``
            is neither a str nor a list.
        """
        res = []

        def _collate(req: Tuple[str, dict]):
            """Sort key: longest tokenized context first, ties broken by the text."""
            # longest-first: time estimates err high, the first item of each
            # batch fixes its padded length, and OOMs surface immediately
            encoded = self.tok_encode(req[0])
            return -len(encoded), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )

        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        if self.batch_size != "auto":
            batch_size = self.batch_size
        elif adaptive_batch_size is not None:
            batch_size = adaptive_batch_size
        else:
            batch_size = 0
        use_scheduler = self.batch_size == "auto" and not adaptive_batch_size
        batch_fn = self._batch_scheduler if use_scheduler else None

        # group by generation kwargs so e.g. greedy decoding and temp=0.8
        # sampling never share a batch; each item is (context, gen_kwargs)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        for chunk in re_ords.get_batched(n=batch_size, batch_fn=batch_fn):
            contexts, all_gen_kwargs = zip(*chunk)
            # the Collator groups by gen_kwargs, so the whole batch shares them
            gen_kwargs = all_gen_kwargs[0]
            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )

            until = None
            kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
            if "until" in kwargs.keys():
                until = kwargs.pop("until")
                if isinstance(until, str):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(
                        f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                    )

            # the EOS token always acts as a stop sequence too
            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
            if until:
                until.append(eos)
            else:
                until = [eos]

            if "max_gen_toks" not in kwargs.keys():
                max_gen_toks = self.max_gen_toks
            else:
                max_gen_toks = kwargs.pop("max_gen_toks")

            # max input length in tokens ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # leave room for max_gen_toks new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # the encoder may use its whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            prompt_width = context_enc.shape[1]
            for cont_toks, context in zip(cont.tolist(), contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    cont_toks = cont_toks[prompt_width:]

                s = self.tok_decode(cont_toks)

                # cut off should-have-been-stopped content post-hoc; the ''
                # separator is skipped (seq2seq tok_decode(eot_token_id) can be '')
                for term in until:
                    if len(term) > 0:
                        s = s.split(term)[0]

                res.append(s)
                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)

        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)
        pbar.close()
        return res
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321

    def get_model_info(self) -> dict:
        """
        Method to get Hugging Face model information for experiment reproducibility.

        Returns:
            dict with the keys
              - "model_num_parameters": parameter count of ``self._model``
                (-1 if it exposes neither ``num_parameters`` nor ``parameters``),
              - "model_dtype": ``self._model.dtype`` if the attribute exists, else "",
              - "model_revision": ``self.revision``,
              - "model_sha": Hub commit SHA of ``self.pretrained`` at
                ``self.revision`` ("" if the lookup fails),
            plus "peft_sha" / "delta_sha" when ``self.peft`` / ``self.delta``
            are set (truthy).
        """

        def get_model_num_params(model) -> int:
            # Prefer the model's own counting helper; otherwise sum tensor sizes.
            if hasattr(model, "num_parameters"):
                return model.num_parameters()
            if hasattr(model, "parameters"):
                return sum(p.numel() for p in model.parameters())
            return -1

        def get_model_dtype(model) -> Union[torch.dtype, str]:
            # Returns the dtype object itself (not a string); "" when absent.
            if hasattr(model, "dtype"):
                return model.dtype
            return ""

        def get_model_sha(pretrained: str, revision: str) -> str:
            # Best-effort lookup: any failure (e.g. an identifier that is not a
            # Hub repo, or no network access) is logged and reported as ""
            # instead of aborting the run.
            try:
                hub_info = HfApi().model_info(repo_id=pretrained, revision=revision)
                return hub_info.sha
            except Exception as e:
                # `Logger.warn` is a deprecated alias of `Logger.warning`.
                eval_logger.warning(
                    f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
                )
                return ""

        model_info = {
            "model_num_parameters": get_model_num_params(self._model),
            "model_dtype": get_model_dtype(self._model),
            "model_revision": self.revision,
            "model_sha": get_model_sha(self.pretrained, self.revision),
        }
        # NOTE(review): adapter/delta SHAs are looked up at the *base model's*
        # revision; if those repos do not have that revision the lookup fails
        # and "" is recorded — confirm this is intended.
        if self.peft:
            model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
        if self.delta:
            model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
        return model_info