import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
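    """Cached state (tensors, lengths, and generation helpers) for one batch of requests."""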
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id, requests=self.requests, size=len(self)
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        decoder_input_lengths = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for i, r in enumerate(pb.requests):
            inputs.append(r.inputs)
            requests_idx_mapping[r.id] = i
            decoder_input_lengths.append(1)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        # Decoder sequence only contains the bos_token
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(len(pb.requests))
            .view(-1, 1)
        )
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
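        # `all_decoder_input_ids` keeps one growing 1-D tensor of decoder ids per request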

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=list(all_decoder_input_ids),
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
        )

    @tracer.start_as_current_span("filter")
    def filter(
        self, requests: List[generate_pb2.Request]
    ) -> Optional["Seq2SeqLMBatch"]:
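        """Keep only the given requests in the batch, re-indexing cached tensors and metadata."""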
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        decoder_input_lengths = []
        offsets = []
        token_offsets = []

        all_decoder_input_ids = []

        next_token_choosers = []
        stopping_criterias = []

        max_input_length = 0
        max_decoder_input_length = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])

            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criterias.append(self.stopping_criterias[idx])

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        decoder_input_ids = self.decoder_input_ids[keep_indices]
        attention_mask = self.attention_mask[keep_indices]
        if self.decoder_attention_mask is not None:
            decoder_attention_mask = self.decoder_attention_mask[keep_indices]
        else:
            decoder_attention_mask = None

        encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]

        past_key_values = [
            [t[keep_indices] for t in layer] for layer in self.past_key_values
        ]

        return Seq2SeqLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=self.padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors"""

        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        offsets = []
        token_offsets = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for i, batch in enumerate(batches):
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

            # Iterate over attention layers
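            # Each layer's `past` holds the decoder key/value tensors first, then the encoder ones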
            for j, past in enumerate(batch.past_key_values):
                _, num_heads, _, head_dim = past[0].shape

                # This will run only once per layer
                if j == len(past_key_values):
                    past_key_values.append([])

                # Decoder past
                for k, t in enumerate(past[:2]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        (max_decoder_input_length - 1),
                        head_dim,
                    )

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    # We slice the past keys and values to remove the padding from previous batches
                    past_key_values[j][k][
                        start_index:end_index,
                        :,
                        -(batch.max_decoder_input_length - 1) :,
                        :,
                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]

                # encoder past
                for k, t in enumerate(past[2:]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        max_input_length,
                        head_dim,
                    )

                    idx = k + 2

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if idx == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    past_key_values[j][idx][
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]

            start_index += len(batch)

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
        )

    def __len__(self):
        return len(self.requests)


class Seq2SeqLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
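        # Run on GPU (bfloat16 when supported) if available, otherwise fall back to CPU in float32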
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
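        # Seq2seq decoding starts from the model's decoder_start_token_id, exposed here as the bos token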
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
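        """Decode generated decoder token ids into text, skipping special tokens."""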
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[List[torch.Tensor]],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
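        """Run one decoding step for the whole batch and return the per-request generations."""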
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # Finished requests
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), logits
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                all_decoder_input_ids, offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with decoder_input_length to remove padding
                # Decode all tokens
                output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                stopped = False

            # Prefill
            if stopping_criteria.current_tokens == 1:
                prefill_tokens = PrefillTokens(
                    [self.tokenizer.bos_token_id],
                    [float("nan")],
                    [self.tokenizer.bos_token],
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

            # Update values
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.max_input_length = max(batch.max_input_length, input_length)
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask as we added a new token to decoder_input_ids
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

        return generations, batch