huggingface.py 47.9 KB
Newer Older
1
import copy
2
import os
3
4
5
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union

6
import torch
7
import torch.nn.functional as F
8
import transformers
9
10
11
12
13
from accelerate import Accelerator, DistributedType, find_executable_batch_size
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
14
15
16
17
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
18
19

from lm_eval import utils
baberabb's avatar
baberabb committed
20
from lm_eval.api.instance import Instance
21
22
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
Baber Abbasi's avatar
Baber Abbasi committed
23
from lm_eval.utils import Collator, stop_sequences_criteria
24

25

26
eval_logger = utils.eval_logger
27

lintangsutawika's avatar
lintangsutawika committed
28

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def _get_accelerate_args(
    device_map_option: Optional[str] = "auto",
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[str] = "./offload",
) -> dict:
    """Build the `accelerate`-related kwargs for `AutoModel.from_pretrained`.

    Args:
        device_map_option: value stored under ``"device_map"``.
        max_memory_per_gpu: if given, the same budget is assigned to every CUDA
            device index reported by ``torch.cuda.device_count()``.
        max_cpu_memory: if given, stored as the ``"cpu"`` memory budget.
        offload_folder: value stored under ``"offload_folder"``.

    Returns:
        A dict with ``"device_map"`` and ``"offload_folder"``, plus ``"max_memory"``
        only when at least one memory budget was set.
    """
    memory_budget = {}
    if max_memory_per_gpu is not None:
        for gpu_index in range(torch.cuda.device_count()):
            memory_budget[gpu_index] = max_memory_per_gpu
    if max_cpu_memory is not None:
        memory_budget["cpu"] = max_cpu_memory

    # an empty budget is omitted entirely rather than passed as {}
    accelerate_kwargs = {"max_memory": memory_budget} if memory_budget else {}
    accelerate_kwargs["device_map"] = device_map_option
    accelerate_kwargs["offload_folder"] = offload_folder
    return accelerate_kwargs
52
53


54
@register_model("hf-auto", "hf", "huggingface")
55
class HFLM(LM):
56
57
58
59
60
61
62
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

63
    AUTO_MODEL_CLASS = None
64
    _DEFAULT_MAX_LENGTH = 2048
haileyschoelkopf's avatar
haileyschoelkopf committed
65

66
67
    def __init__(
        self,
        pretrained: Optional[Union[str, transformers.PreTrainedModel]] = "gpt2",
        backend: Optional[
            Literal["default", "causal", "seq2seq"]
        ] = "default",  # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ] = None,
        truncation: Optional[bool] = False,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        max_batch_size: Optional[int] = 64,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
        # PEFT and quantization options
        peft: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        **kwargs,
    ) -> None:
        """
        Load a HF model (or wrap an already-initialized one) together with its
        tokenizer, and configure device placement, batching and data parallelism.

        Args:
            pretrained: HF Hub name / local path to load, or an already-initialized
                ``transformers.PreTrainedModel``. When a model object is passed, many
                other arguments are ignored and distributed setup is skipped.
            backend: "causal" or "seq2seq" to force the model class; "default"
                infers it from the model config.
            revision: model revision; ``subfolder`` (if given) is appended to it as
                ``revision/subfolder`` when loading by name.
            tokenizer: tokenizer name/path or tokenizer object. Defaults to the
                tokenizer matching ``pretrained``.
            truncation: stored on ``self.truncation``.
            max_length: explicit maximum sequence length (see the ``max_length``
                property).
            device: requested device, honored when neither ``parallelize`` nor a
                multi-process Accelerate launch is used.
            dtype: passed to ``_create_model``, which converts it with
                ``utils.get_dtype``.
            batch_size: an int, or "auto" / "auto:N" (N is stored as a float in
                ``self.batch_schedule``).
            max_batch_size: stored on ``self.max_batch_size``.
            trust_remote_code, use_fast_tokenizer: forwarded to the HF loaders.
            parallelize, device_map_option, max_memory_per_gpu, max_cpu_memory,
            offload_folder: naive model-parallel (``device_map``) options.
            peft: optional PEFT adapter, loaded via ``PeftModel.from_pretrained``.
            autogptq: load via AutoGPTQ; ``True`` or a quantized checkpoint filename.
            **kwargs: extra keyword arguments forwarded to ``_create_model``.
        """
        super().__init__()

        # optionally: take in an already-initialized transformers.PreTrainedModel
        if not isinstance(pretrained, str):
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
            )
            assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
            # The tokenizer is resolved exactly once, by `_create_tokenizer` below.
            # It handles tokenizer objects, tokenizer names, and falling back to
            # `self.model.name_or_path`, so no separate load is needed here.
        else:
            assert isinstance(device, str)
            assert isinstance(pretrained, str)
            assert isinstance(batch_size, (int, str))

            # NOTE: `gpus` and `accelerator` are only bound on this branch; every
            # later use is guarded by `isinstance(pretrained, str)`.
            gpus = torch.cuda.device_count()
            accelerator = Accelerator()
            if accelerator.num_processes > 1:
                self.accelerator = accelerator

            if not (parallelize or accelerator.num_processes > 1):
                # use user-passed device
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
                    + ["mps", "mps:0"]
                )
                if device and device in device_list:
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and version.parse(
                        torch.__version__
                    ) < version.parse("2.1"):
                        raise RuntimeError(
                            f"mps requires torch >= 2.1. You have {torch.__version__}"
                        )
                else:
                    eval_logger.info("Device not specified")
                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                    self._device = (
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    )
            else:
                if device != "cuda":
                    eval_logger.info(
                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                    )
                # TODO: include in warning that `load_in_8bit` etc. affect this too
                self._device = torch.device(device)

            # TODO: update this to be less of a hack once subfolder is fixed in HF
            revision = revision + ("/" + subfolder if subfolder is not None else "")

            self._get_config(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
            )

        # determine which of 'causal' and 'seq2seq' backends to use
        self._get_backend(
            config=self.config, backend=backend, trust_remote_code=trust_remote_code
        )

        # if we passed `pretrained` as a string, initialize our model now
        if isinstance(pretrained, str):
            self._create_model(
                pretrained=pretrained,
                revision=revision,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                parallelize=parallelize,
                device_map_option=device_map_option,
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                peft=peft,
                autogptq=autogptq,
                **kwargs,
            )

        # access self._model through self.model property outside this method
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model.tie_weights()

        if isinstance(pretrained, str) and (gpus >= 1 or str(self.device) == "mps"):
            # TODO: can remove this whole snippet except in the mps case, perhaps?
            if not (parallelize or autogptq or hasattr(self, "accelerator")):
                # place model onto device requested manually,
                # if not using HF Accelerate or device_map
                # or any other option that preloads model onto device
                try:
                    self.model.to(self.device)
                except ValueError:
                    eval_logger.debug(
                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                    )

        self._create_tokenizer(
            pretrained,
            tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
        )

        self.truncation = truncation

        self.vocab_size = self.tokenizer.vocab_size
        # select (or create) a pad token to use
        if self.tokenizer.pad_token:
            pass
        elif self.tokenizer.unk_token:
            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
        elif self.tokenizer.eos_token:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        else:
            if self.config.model_type == "qwen":
                # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
                self.tokenizer.pad_token = "<|endoftext|>"
            else:
                self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

        self._max_length = max_length

        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size

        if str(batch_size).startswith("auto"):
            # "auto" or "auto:N"; the optional N becomes the batch schedule
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if parallelize:
                    if accelerator.num_processes > 1:
                        raise RuntimeError(
                            "Attempted to use both a HF Accelerate `device_map` and to launch via `accelerate launch`. If this is the case, please either remove `parallelize=True` from --model_args or launch outside of the Accelerate launcher."
                        )
                    else:
                        pass
                elif accelerator.num_processes == 1:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
                else:
                    if gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                    assert (
                        accelerator.distributed_type
                        in [
                            DistributedType.FSDP,
                            DistributedType.MULTI_GPU,
                        ]
                    ), "Unsupported distributed type provided. Only DDP and FSDP are supported."
                    if accelerator.distributed_type == DistributedType.FSDP:
                        self._model = accelerator.prepare(self.model)
                    else:
                        self._model = accelerator.prepare_model(
                            self.model, evaluation_mode=True
                        )
                    self._device = torch.device(
                        f"cuda:{accelerator.local_process_index}"
                    )
                    self.accelerator = accelerator

                    if self.accelerator.is_local_main_process:
                        eval_logger.info(f"Using {gpus} devices with data parallelism")

                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1
haileyschoelkopf's avatar
haileyschoelkopf committed
310

311
312
313
314
315
    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

316
317
318
319
320
321
322
323
    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

324
325
326
327
328
329
330
    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
331
332
333
334
335
336
337
338
339
340
341
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH
342

343
    @property
Ethan Smith's avatar
Ethan Smith committed
344
    def max_gen_toks(self) -> int:
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
    def _get_backend(
        self,
        config: transformers.AutoConfig,
        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
        trust_remote_code: Optional[bool] = False,
    ) -> None:
        """
        Helper method during initialization.

        Sets ``self.AUTO_MODEL_CLASS`` to ``transformers.AutoModelForCausalLM``
        (decoder-only) or ``transformers.AutoModelForSeq2SeqLM`` (encoder-decoder).

        Args:
            config: model config whose ``model_type`` is used when ``backend`` is
                "default".
            backend: "causal" / "seq2seq" forces the class; "default" infers it
                from ``config.model_type``.
            trust_remote_code: when False, a warning is logged if the model type is
                in neither HF registry.
        """
        assert backend in ["default", "causal", "seq2seq"]

        if backend != "default":
            # if we've settled on non-default backend, use that manually
            if backend == "causal":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            elif backend == "seq2seq":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
            eval_logger.info(
                f"Overrode HF model backend type, and using type '{backend}'"
            )
        else:
            # Determine the default HF backend from the *passed* config (both
            # checks must consult the same object). A missing model_type falls
            # through to the causal default below.
            model_type = getattr(config, "model_type", None)
            if model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                # first check if model type is listed under seq2seq models, since some
                # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                # these special cases should be treated as seq2seq models.
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
            elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            else:
                if not trust_remote_code:
                    # implicit concatenation: a backslash continuation would embed
                    # the source indentation into the logged message
                    eval_logger.warning(
                        "HF model type is neither marked as CausalLM or Seq2SeqLM. "
                        "This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
                    )
                # if model type is neither in HF transformers causal or seq2seq model registries
                # then we default to AutoModelForCausalLM
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

        assert self.AUTO_MODEL_CLASS in [
            transformers.AutoModelForCausalLM,
            transformers.AutoModelForSeq2SeqLM,
        ]
        return None

    def _get_config(
        self,
        pretrained: str,
        revision: str = "main",
        trust_remote_code: bool = False,
    ) -> None:
        """Load the model config with ``transformers.AutoConfig`` and store it on ``self._config``.

        Args:
            pretrained: HF Hub name or local path of the model.
            revision: revision to load.
            trust_remote_code: forwarded to ``AutoConfig.from_pretrained``.
        """
        loader_options = {
            "revision": revision,
            "trust_remote_code": trust_remote_code,
        }
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained, **loader_options
        )

    def _create_model(
        self,
        pretrained: str,
        revision: Optional[str] = "main",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        trust_remote_code: Optional[bool] = False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # (accelerate naive PP (device_map) options)
        parallelize: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        # PEFT and quantization options
        peft: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        **kwargs,
    ) -> None:
        """
        Initializes an HF or HF-compatible PreTrainedModel from scratch
        inside HFLM, using the kwargs passed into self.__init__().

        Also handles functionality such as AutoGPTQ usage and PEFT wrapping.

        For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
        (such as PyTorch models that are nearly, but not quite, fully mirroring
        HF's public interface relied on in this HFLM class)
        please consider subclassing HFLM and overriding this and other methods as needed.

        The loaded model is stored on ``self._model``.

        Raises:
            AssertionError: ``load_in_4bit`` is requested with transformers < 4.30.0,
                or with a PEFT adapter and peft < 0.4.0.
            Exception: ``autogptq`` is requested but ``auto_gptq`` is not installed.
        """
        model_kwargs = kwargs if kwargs else {}

        if parallelize:
            model_kwargs.update(
                _get_accelerate_args(
                    device_map_option,  # TODO: phase out device_map_option?
                    max_memory_per_gpu,
                    max_cpu_memory,
                    offload_folder,
                )
            )
        elif "device_map" not in model_kwargs:
            # set a device_map to initialize model on the right GPU.
            # this is needed because it seems that the default behavior
            # for quantized models now seems to be device_map="auto"
            # which breaks data-parallel mode.
            if hasattr(self, "accelerator"):
                model_kwargs.update(
                    {"device_map": {"": f"cuda:{self.accelerator.local_process_index}"}}
                )
            else:
                model_kwargs.update({"device_map": {"": str(self.device)}})

        if not autogptq:
            if model_kwargs.get("load_in_4bit", None):
                # Compare parsed versions: plain string comparison misorders
                # releases such as "4.100.0" vs "4.30.0".
                assert version.parse(transformers.__version__) >= version.parse(
                    "4.30.0"
                ), "load_in_4bit requires transformers >= 4.30.0"
                if model_kwargs.get("bnb_4bit_compute_dtype", None):
                    # normalize through utils.get_dtype (presumably maps dtype
                    # names to torch dtypes — see utils)
                    model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype(
                        model_kwargs["bnb_4bit_compute_dtype"]
                    )
            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
                torch_dtype=utils.get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )
        else:
            try:
                from auto_gptq import AutoGPTQForCausalLM
            except ModuleNotFoundError as err:
                # a single message string; passing two positional args would
                # turn the exception message into a tuple
                raise Exception(
                    "Tried to load auto_gptq, but auto-gptq is not installed, "
                    "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]"
                ) from err

            self._model = AutoGPTQForCausalLM.from_quantized(
                pretrained,
                trust_remote_code=trust_remote_code,
                # `autogptq=True` lets AutoGPTQ pick the checkpoint; a string names the file
                model_basename=None if autogptq is True else Path(autogptq).stem,
                use_safetensors=True
                if autogptq is True
                else autogptq.endswith(".safetensors"),
                **model_kwargs,
            )

        if peft:
            if model_kwargs.get("load_in_4bit", None):
                # parsed comparison: "0.10.0" >= "0.4.0" is False as strings
                assert version.parse(PEFT_VERSION) >= version.parse(
                    "0.4.0"
                ), "load_in_4bit requires peft >= 0.4.0"
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )

        return None

    def _create_tokenizer(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ],
        revision: Optional[str] = "main",
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
    ) -> None:
        """
        Helper method during initialization: set ``self.tokenizer``.

        A tokenizer object passed as ``tokenizer`` is used as-is. Otherwise a
        tokenizer is loaded with ``transformers.AutoTokenizer.from_pretrained``
        from, in order of preference: the ``tokenizer`` name, the ``pretrained``
        name, or ``self.model.name_or_path`` when ``pretrained`` is a model object.
        """
        if tokenizer and not isinstance(tokenizer, str):
            # pre-initialized tokenizer: validate type and adopt it directly
            assert isinstance(
                tokenizer, transformers.PreTrainedTokenizer
            ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
            self.tokenizer = tokenizer
            return None

        if tokenizer:
            source_name = tokenizer
        elif isinstance(pretrained, str):
            source_name = pretrained
        else:
            # get the HF hub name via accessor on model
            source_name = self.model.name_or_path

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            source_name,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast=use_fast_tokenizer,
        )
        return None

Ethan Smith's avatar
Ethan Smith committed
576
    def _detect_batch_size(self, requests=None, pos: int = 0):
        """Find the largest batch size whose forward pass runs, starting at ``self.max_batch_size``.

        Probe batches of ones are fed through ``self._model_call`` five times per
        candidate size. ``accelerate.find_executable_batch_size`` retries with a
        smaller size when a candidate fails.

        Args:
            requests: optional indexable whose items unpack as
                ``(_, context_enc, continuation_enc)`` (presumably token-id lists).
                Item ``pos`` sizes the probe. When empty/None, ``self.max_length``
                is used for every probe dimension.
            pos: index into ``requests`` of the item used for sizing.

        Returns:
            The detected batch size. With ``world_size > 1``, the minimum across
            all ranks.
        """
        if requests:
            _, context_enc, continuation_enc = requests[pos]
            max_length = len(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
            )
            max_context_enc = len(context_enc[-(self.max_length + 1) :])
            max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
        else:
            max_length = self.max_length
            # Bind these too, so the seq2seq probe below never hits a NameError
            # when no sample request is supplied.
            max_context_enc = max_length
            max_cont_enc = max_length

        # if OOM, then halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                length = max(max_context_enc, max_cont_enc)
                batched_conts = torch.ones(
                    (batch_size, length), device=self.device
                ).long()
                test_batch = torch.ones((batch_size, length), device=self.device).long()
                call_kwargs = {
                    "attn_mask": test_batch,
                    "labels": batched_conts,
                }
            else:
                call_kwargs = {}
                test_batch = torch.ones(
                    (batch_size, max_length), device=self.device
                ).long()
            for _ in range(5):
                out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)  # noqa: F841

            return batch_size

        batch_size = forward_batch()

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
            max_rnk_bs = torch.tensor([batch_size], device=self.device)
            gathered = (
                self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
            )
            batch_size = min(gathered)

        utils.clear_torch_cache()
        return batch_size

baberabb's avatar
baberabb committed
625
626
627
    def tok_encode(
        self, string: str, left_truncate_len=None, add_special_tokens=None
    ) -> List[int]:
haileyschoelkopf's avatar
haileyschoelkopf committed
628
        """ """
629
630
631
632
633
        if add_special_tokens is None:
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                add_special_tokens = False
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                add_special_tokens = True
634
635

        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
haileyschoelkopf's avatar
haileyschoelkopf committed
636

637
638
639
        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
haileyschoelkopf's avatar
haileyschoelkopf committed
640

641
642
        return encoding

haileyschoelkopf's avatar
haileyschoelkopf committed
643
    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a batch of strings into padded PyTorch tensors.

        Unlike ``tok_encode``, this pads to the longest string and returns tensors.

        Args:
            strings: texts to tokenize.
            padding_side: padding side used for this call only. The tokenizer's
                previous setting is always restored, even if tokenization raises.
            left_truncate_len: if truthy, keep only the last ``left_truncate_len``
                columns of the ids and the attention mask.
            truncation: forwarded to the tokenizer.

        Returns:
            ``(input_ids, attention_mask)``.
        """
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            # seq2seq models get special tokens added; causal models do not
            add_special_tokens = (
                self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
            )
            encoding = self.tokenizer(
                strings,
                truncation=truncation,
                padding="longest",
                return_tensors="pt",
                add_special_tokens=add_special_tokens,
            )
        finally:
            # never leave the shared tokenizer with a temporary padding side
            self.tokenizer.padding_side = old_padding_side

        if left_truncate_len:
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]

        return encoding["input_ids"], encoding["attention_mask"]

675
676
677
678
679
680
681
682
    def tok_decode(self, tokens):
        """Decode token ids to text.

        Special tokens are skipped for seq2seq models. Returns ``None`` if
        ``AUTO_MODEL_CLASS`` is neither supported class.
        """
        model_class = self.AUTO_MODEL_CLASS
        if model_class == transformers.AutoModelForSeq2SeqLM:
            return self.tokenizer.decode(tokens, skip_special_tokens=True)
        if model_class == transformers.AutoModelForCausalLM:
            return self.tokenizer.decode(tokens)
        return None

    def _model_call(self, inps, attn_mask=None, labels=None):
        """
        Run a forward pass without gradients and return the model's logits.

        :param inps: torch.Tensor
            Input token ids, shape [batch, (sequence_ctx + sequence_cont)] or
            [batch, sequence_ctx]; the sequence length may differ between calls.
        :param attn_mask: torch.Tensor, optional
            Encoder attention mask, shape [batch, (sequence_ctx + sequence_cont)].
            Required (together with ``labels``) for seq2seq models, and must be
            omitted for causal LMs.
        :param labels: torch.Tensor, optional
            Decoder labels, shape [batch, (sequence_ctx + sequence_cont)].
            Required (together with ``attn_mask``) for seq2seq models, and must be
            omitted for causal LMs.
        :return
            Logits of shape [batch, sequence, vocab] from the model's decoder.
        """
        with torch.no_grad():
            # causal path: neither seq2seq-only argument was supplied
            if attn_mask is None and labels is None:
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                return self.model(inps).logits

            # seq2seq path: both arguments are mandatory together
            assert attn_mask is not None and labels is not None
            assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
            return self.model(
                input_ids=inps, attention_mask=attn_mask, labels=labels
            ).logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """
        Generate continuations for a batch of context token ids.

        :param context: torch.Tensor
            Context token ids; ``context.shape[1]`` and ``context.shape[0]`` are
            passed to ``stop_sequences_criteria`` as the initial input length and
            batch size respectively.
        :param max_length: int
            Forwarded to ``self.model.generate`` as ``max_length``.
        :param stop: List[str]
            Stop sequences, turned into a stopping criterion via
            ``stop_sequences_criteria``.
        :param generation_kwargs:
            Extra keyword arguments forwarded to ``self.model.generate``.
        :return
            The output of ``self.model.generate``. For causal LMs, callers slice off
            the context prefix from the returned ids.
        """
        # if do_sample is false and temp==0.0:
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF.
        # (generation_kwargs is a fresh dict built from **kwargs, so popping from
        # it does not affect the caller.)
        do_sample = generation_kwargs.get("do_sample", None)
        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")

        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )
        return self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            **generation_kwargs,
        )
726
727
728

    def _select_cont_toks(self, logits, contlen=None, inplen=None):
        """Slice per-request logits down to the positions that score the continuation.

        :param logits: torch.Tensor
            Logits for a single request (sequence dimension first).
        :param contlen: int
            Number of continuation tokens to keep.
        :param inplen: int, optional
            Causal LMs only: end position of the real (unpadded) input. Must be
            falsy for seq2seq models.
        :return
            The sliced logits; for any other model class they are returned as-is.
        """
        model_class = self.AUTO_MODEL_CLASS
        if model_class == transformers.AutoModelForCausalLM:
            assert (
                contlen and inplen
            ), "Must pass input len and cont. len to select scored logits for causal LM"
            # drop the context positions in front and any right-padding behind,
            # keeping just the `contlen` positions ending at `inplen`
            first = inplen - contlen
            logits = logits[first:inplen]
        elif model_class == transformers.AutoModelForSeq2SeqLM:
            assert (
                contlen and not inplen
            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
            # decoder-side logits only: just drop the right-padding
            logits = logits[:contlen]
        return logits

baberabb's avatar
baberabb committed
745
746
747
    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
748
749
750
751
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
752
753
754
755
756
757

        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
        context_enc = self.tok_encode(context, add_special_tokens=False)

        # whole_enc = self.tok_encode(context + continuation)
        # context_enc = self.tok_encode(context, add_special_tokens=False)
758
759
760
761
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

baberabb's avatar
baberabb committed
762
    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Score each request's continuation given its context.

        Each request's ``args`` is a ``(context, continuation)`` string pair. An
        empty context is replaced by the single EOT token; otherwise the pair is
        split into tokens by ``_encode_pair``. Scoring is delegated to
        ``_loglikelihood_tokens``.
        """
        tokenized_requests = []
        for request in requests:
            context, continuation = request.args
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
                continuation_enc = self.tok_encode(continuation)
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)
            tokenized_requests.append(
                ((context, continuation), context_enc, continuation_enc)
            )

        return self._loglikelihood_tokens(tokenized_requests)

baberabb's avatar
baberabb committed
778
    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        """Compute the total log-likelihood of each request's full string.

        The tokenized string is split into disjoint rolling windows of at most
        ``self.max_length`` tokens (prefixed by the EOT token). Every window is
        scored with ``_loglikelihood_tokens``, and the window log-probs are summed
        per string.

        With ``world_size > 1`` each process pads its window list up to the
        largest count gathered across processes, so the inner iterator doesn't
        hang; the padded results are discarded before summing.
        """
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            adaptive_batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {adaptive_batch_size}")

        loglikelihoods = []
        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [
                (None,) + window
                for window in map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            ]

            n_padding = 0
            if self.world_size > 1:
                # We pad out the external document-level iterator so the inner iterator doesn't hang
                local_count = torch.tensor(len(windows), device=self.device)
                all_counts = (
                    self.accelerator.gather(local_count).cpu().detach().numpy().tolist()
                )
                n_padding = max(all_counts) - all_counts[self.rank]
                if n_padding > 0:
                    windows += n_padding * [windows[0]]

            window_results = self._loglikelihood_tokens(
                windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )
            if (self.world_size > 1) and (n_padding > 0):
                # drop the results that belong to the padding windows
                window_results = window_results[:-n_padding]

            # keep only the log-prob of each window (discard is_greedy) and total them
            loglikelihoods.append(sum(result[0] for result in window_results))

        return loglikelihoods
Zhiwei Zhuang's avatar
Zhiwei Zhuang committed
833

834
835
836
837
838
839
840
841
842
843
844
845
846
    def _batch_scheduler(self, pos, n_reordered_requests):
        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
        if sched in self.batch_sizes:
            return self.batch_sizes[sched]
        if (len(self.batch_sizes) > 1) and (
            self.batch_sizes[sched - 1] == self.max_batch_size
        ):
            # if previous batch size is already maximal, skip recomputation
            self.batch_sizes[sched] = self.max_batch_size
            return self.batch_sizes[sched]
        print(
            f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
        )
Zhiwei Zhuang's avatar
Zhiwei Zhuang committed
847
        self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
848
849
        print(f"Determined largest batch size: {self.batch_sizes[sched]}")
        return self.batch_sizes[sched]
850

Ethan Smith's avatar
Ethan Smith committed
851
    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: int = None,
    ) -> List[Tuple[float, bool]]:
        """
        Score pre-tokenized (context, continuation) requests with the model.

        Requests are reordered longest-first through ``Collator``, then batched and
        padded. Each batch goes through ``_model_call``, and the log-probs of each
        request's continuation tokens are summed.

        :param requests:
            ``((context, continuation), context_enc, continuation_enc)`` triples;
            the first element is used as the cache key.
        :param disable_tqdm: bool
            Suppress the progress bar.
        :param override_bs: int, optional
            Batch size to use when ``self.batch_size == "auto"`` instead of the
            per-segment ``_batch_scheduler``.
        :return
            ``(sum of continuation log-probs, whether greedy argmax reproduces the
            continuation exactly)`` per request, mapped back through
            ``get_original`` to the input order.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        results = []

        def _sort_key(req):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            tokens = req[1] + req[2]
            return -len(tokens), tuple(tokens)

        reorderer = Collator(requests, sort_fn=_sort_key)

        is_causal = self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
        is_seq2seq = self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM

        # automatic (variable) batch size detection for vectorization
        n_reordered_requests = len(reorderer)
        auto_batch = self.batch_size == "auto"
        if not auto_batch:
            batch_size = self.batch_size
        elif override_bs is not None:
            batch_size = override_bs
        else:
            batch_size = 0
        use_scheduler = auto_batch and n_reordered_requests > 0 and not override_bs
        batch_fn = self._batch_scheduler if use_scheduler else None

        batches = reorderer.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
        for batch in batches:
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in batch:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if is_causal:
                    window = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                    inp = torch.tensor(
                        window[:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif is_seq2seq:
                    inp = torch.tensor(
                        context_enc[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

                    # TODO: left-shift these?
                    # TODO: our code assumes we never end up truncating conts for either model type
                    cont = torch.tensor(
                        continuation_enc[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (contlen,) = cont.shape
                    conts.append(cont)

                    if padding_len_cont is None:
                        padding_len_cont = contlen
                    else:
                        padding_len_cont = max(padding_len_cont, contlen)

                if padding_len_inp is None:
                    padding_len_inp = inplen
                else:
                    padding_len_inp = max(padding_len_inp, inplen)

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if is_causal:
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif is_seq2seq:
                # TODO: left-pad encoder inps and mask?
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
                batched_conts = utils.pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
                batched_encoder_mask = utils.pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            multi_logits = F.log_softmax(
                self._model_call(batched_inps, **call_kwargs), dim=-1
            )  # [batch, padding_length (inp or cont), vocab]

            for (cache_key, _, _), logits, inplen, cont_toks in zip(
                batch, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
                # also discards + checks for "virtual tokens" in the causal LM's input window
                # from prompt/prefix tuning tokens, if applicable
                if is_causal:
                    ctx_len = inplen + (logits.shape[0] - padding_len_inp)
                else:
                    ctx_len = None
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                target = torch.tensor(
                    cont_toks, dtype=torch.long, device=self.device
                ).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == target).all()

                # Obtain log-probs at the corresponding continuation token indices
                token_logprobs = torch.gather(logits, 2, target.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(token_logprobs.sum()), bool(max_equal))
                results.append(answer)

                self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)

        pbar.close()

        return reorderer.get_original(results)

baberabb's avatar
baberabb committed
1035
    def generate_until(self, requests: List[Instance]) -> List[str]:
        """
        Generate text for each request until a stop sequence is produced.

        Each request's ``args`` is a ``(context, gen_kwargs)`` pair. Requests are
        grouped by their generation kwargs, sorted longest-context-first, batched,
        and generated with ``_model_generate``. Decoded outputs are then cut at
        the first occurrence of each stop sequence.

        Keys consumed from ``gen_kwargs`` (everything else goes to
        ``_model_generate``):
          - ``until``: a str or list of str stop sequences; defaults to the decoded
            EOT token when absent or empty.
          - ``max_gen_toks``: max new tokens; defaults to ``self.max_gen_toks``.
          - ``max_length``: if absent, set to padded context length + max_gen_toks.

        :return
            Generated strings, mapped back through ``get_original`` to input order.
        :raises ValueError
            If ``gen_kwargs`` is not a dict, or ``until`` is neither a str nor a list.
        """
        res = []

        def _collate(x):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
        # Always bound, so the batch-size logic below never depends on
        # short-circuit evaluation to avoid an unbound name.
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # detect a single batch size once, up front
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            if not isinstance(gen_kwargs, dict):
                # report the offending value itself (`kwargs` is not bound here)
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                )
            # unpack our keyword arguments.
            kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
            until = None
            if "until" in kwargs.keys():
                until = kwargs.pop("until")
                if isinstance(until, str):
                    # a single stop string: wrap the string, not the kwargs dict
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(
                        f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                    )
            if not until:
                until = [self.tok_decode(self.eot_token_id)]
            max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

            # set the max length in tokens of inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        s = s.split(term)[0]

                res.append(s)

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res