import copy
from typing import Optional, Union

import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator
from tqdm import tqdm
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.logger import eval_logger
from lm_eval.utils import stop_sequences_criteria


def _get_accelerate_args(
    device_map_option: Optional[str] = "auto",
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[str] = "./offload",
) -> dict:
    """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
    max_memory = {}
    if max_memory_per_gpu is not None:
        max_memory_per_gpu_map = {
            device_idx: max_memory_per_gpu
            for device_idx in range(torch.cuda.device_count())
        }
        max_memory.update(max_memory_per_gpu_map)
    if max_cpu_memory is not None:
        max_memory["cpu"] = max_cpu_memory

    args = {}
    if max_memory:
        args["max_memory"] = max_memory
    args["device_map"] = device_map_option
    args["offload_folder"] = offload_folder
    return args
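
# Illustrative only (hypothetical values; nothing here is executed): on a host
# with two CUDA devices, a call such as
#   _get_accelerate_args(max_memory_per_gpu="40GiB", max_cpu_memory="64GiB")
# would return kwargs along the lines of
#   {"max_memory": {0: "40GiB", 1: "40GiB", "cpu": "64GiB"},
#    "device_map": "auto", "offload_folder": "./offload"},
# which `HFLM.__init__` below splats into `from_pretrained(..., **model_kwargs)`.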


@register_model("hf-auto", "hf", "huggingface")
class HFLM(LM):
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

    AUTO_MODEL_CLASS = None
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        max_length=None,
        subfolder=None,
        tokenizer=None,
        batch_size=1,
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        # arguments used for splitting a model across GPUs naively.
        use_accelerate: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
    ):
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        gpus = torch.cuda.device_count()

        if gpus <= 1:
            if device:
                if device not in ["cuda", "cpu"]:
                    device = int(device)
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
            else:
                eval_logger.info("Device not specified")
                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
            self._rank = 0
            self._world_size = 1

        elif not use_accelerate:
            # placeholder; reassigned below once Accelerate's data-parallel
            # launcher hands this process its GPU
            self._device = "cpu"
        else:
            self._device = device

        model_kwargs = {}
        if use_accelerate:
            model_kwargs = _get_accelerate_args(
                device_map_option,
                max_memory_per_gpu,
                max_cpu_memory,
                offload_folder,
            )
        eval_logger.info(f"Model kwargs: {model_kwargs}")

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        # get config
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
        )

        if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
        else:
            self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM

        assert self.AUTO_MODEL_CLASS in [
            transformers.AutoModelForCausalLM,
            transformers.AutoModelForSeq2SeqLM,
        ]

        self._model = self.AUTO_MODEL_CLASS.from_pretrained(
            pretrained,
            revision=revision,
            low_cpu_mem_usage=low_cpu_mem_usage,
            **model_kwargs,
            torch_dtype=utils.get_dtype(dtype),
        )
        if not use_accelerate:
            # `accelerate` places weights itself via `device_map`; otherwise we
            # move the model to the target device here
            self._model = self._model.to(self._device)
        # forever after, access self._model through self.model property
        self.model.eval()
        # TODO: call self.model.tie_weights() here

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
        )

        self.vocab_size = self.tokenizer.vocab_size

        self._max_length = max_length

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        # if use_accelerate:
        #     if "lm_head" in self.model.hf_device_map:
        #         # `accelerate` can place `lm_head` weights on a different device than
        #         # the user specified one so we force `self._device` to be the same as
        #         # `lm_head`'s.
        #         self._device = self.model.hf_device_map["lm_head"]
        if hasattr(self.model, "hf_device_map"):
            eval_logger.info(
                f"Device: {self._device}; device map: {self.model.hf_device_map}"
            )
        # multi-GPU data parallelism via `accelerate launch` (distinct from the
        # naive model sharding enabled by `use_accelerate`)
        if gpus > 1 and not use_accelerate:
            accelerator = Accelerator()
            if gpus > accelerator.num_processes:
                # TODO: make sure there's still never an edge case where we unintentionally default to CPU
                eval_logger.warning(
                    "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                    "If you would like to use data parallelism, please launch the script "
                    "with 'accelerate launch *script*'. "
                    f"Current run will proceed with {accelerator.num_processes} devices."
                )
                self._rank = accelerator.local_process_index
                self._world_size = accelerator.num_processes
                # manually set model to use gpu, for case where many GPUs available but
                # only seek to use one
                self._device = (
                    torch.device(f"cuda:{accelerator.local_process_index}")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
                self.model.to(self.device)
            else:
                self._model = accelerator.prepare(self.model)
                self._device = torch.device(f"cuda:{accelerator.local_process_index}")
                self.accelerator = accelerator

                if self.accelerator.is_local_main_process:
                    eval_logger.info(f"Using {gpus} devices with data parallelism")

                self._rank = self.accelerator.local_process_index
                self._world_size = self.accelerator.num_processes

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            # HF tokenizers report this huge sentinel when no max length is set
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH
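    # For illustration (example values, not checked here): "gpt2" resolves its
    # max_length from `n_positions` in the model config (1024), while a model
    # whose config lacks all three attributes and whose tokenizer reports the
    # "unset" sentinel above falls back to _DEFAULT_MAX_LENGTH (2048).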

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None):
        """ """
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            add_special_tokens = False
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            add_special_tokens = True

        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    def tok_decode(self, tokens):
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            return self.tokenizer.decode(tokens)
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def _model_call(self, inps, attn_mask=None, labels=None):
        """
        :param inps: torch.Tensor
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. The size of the sequence may vary from call to call.
        :param attn_mask: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :param labels: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :return
            A torch tensor of shape [batch, sequence, vocab] with the
            logits returned from the model's decoder
        """
        with torch.no_grad():
            if attn_mask is not None or labels is not None:
                assert attn_mask is not None and labels is not None
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
                return self.model(
                    input_ids=inps, attention_mask=attn_mask, labels=labels
                ).logits
            else:
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                return self.model(inps).logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # we require users to pass do_sample=True explicitly
        # for non-greedy gen. This should be reevaluated when considering beam search.
        if "do_sample" not in generation_kwargs.keys():
            generation_kwargs["do_sample"] = False
        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, 1, context.shape[0]
        )
        return self.model.generate(
            context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.eot_token_id,
            use_cache=True,
            **generation_kwargs,
        )

    def _select_cont_toks(self, logits, contlen=None, inplen=None):
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            assert (
                contlen and inplen
            ), "Must pass input len and cont. len to select scored logits for causal LM"
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            logits = logits[inplen - contlen : inplen]
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            assert (
                contlen and not inplen
            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
            # only discard right-padding.
            # the logits input to this fn only contain decoder-side tokens.
            logits = logits[:contlen]

        return logits

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)
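    # For illustration: a request with args ("The cat sat on the", " mat") is
    # encoded into separate context and continuation token lists, and only the
    # continuation tokens are scored by `_loglikelihood_tokens` below.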

    def loglikelihood_rolling(self, requests):
        loglikelihoods = []
        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            pad_amnt = 0
            if self.world_size > 1:
                # We pad out the external document-level iterator so the inner iterator doesn't hang
                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
                gathered = (
                    self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
                )

                pad_amnt = max(gathered) - gathered[self.rank]
                if pad_amnt > 0:
                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            if (self.world_size > 1) and (pad_amnt > 0):
                string_nll = [x[0] for x in string_nll[:-pad_amnt]]
            else:
                # discard is_greedy
                string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)
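        # e.g. a 10-token (context + continuation) pair gets sort key (-10, ...)
        # and is scheduled before a 3-token pair with key (-3, ...): the longest
        # (most memory-hungry) inputs run first, so any OOM surfaces immediately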

        # TODO: automatic (variable) batch size detection for vectorization
        re_ord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
            self.batch_size,
        ):

            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

                    cont = torch.tensor(
                        (continuation_enc)[-self.max_length :],
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
                        dtype=torch.long,
                        device=self.device,
                    )
                    (contlen,) = cont.shape

                    conts.append(cont)

                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )

                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
                batched_conts = utils.pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
                batched_encoder_mask = utils.pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            multi_logits = F.log_softmax(
                self._model_call(batched_inps, **call_kwargs), dim=-1
            )  # [batch, padding_length (inp or cont), vocab]

            for (cache_key, _, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):

                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
                ctx_len = (
                    inplen
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                    else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(
                    cont_toks, dtype=torch.long, device=self.device
                ).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                res.append(answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer([req.args for req in requests], _collate)

        for context, gen_kwargs in tqdm(re_ord.get_reordered()):
            until = None
            if isinstance(gen_kwargs, dict):
                gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in gen_kwargs.keys():
                    until = gen_kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
                )
            if not until:
                until = [self.tok_decode(self.eot_token_id)]
            if "max_gen_toks" in gen_kwargs.keys():
                max_gen_toks = gen_kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks
            # first stop sequence is used to halt generation upon encountering it
            primary_until = until[0]

            # set the max length in tokens of inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            context_enc = torch.tensor(
                [self.tok_encode(context, left_truncate_len=max_ctx_len)],
                device=self.device,
            )

            cont = self._model_generate(
                context=context_enc,
                max_length=context_enc.shape[1] + max_gen_toks,
                stop=primary_until,
                **gen_kwargs,
            )

            cont_toks_list = cont[0].tolist()
            # discard context toks if using causal decoder-only LM
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                cont_toks_list = cont_toks_list[context_enc.shape[1] :]

            s = self.tok_decode(cont_toks_list)

            # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
            for term in until:
                if len(term) > 0:  # ignore '' separators (seq2seq case, where special tokens may decode to '')
                    s = s.split(term)[0]

            res.append(s)

        return re_ord.get_original(res)
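

# A minimal usage sketch (illustrative: the model name, device handling, and
# running this module directly are assumptions, not part of the harness workflow):
if __name__ == "__main__":
    lm = HFLM(
        pretrained="gpt2",
        device="cuda" if torch.cuda.is_available() else "cpu",
        batch_size=1,
    )
    print(lm.max_length, lm.max_gen_toks, lm.eot_token_id)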