import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
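    """Cached state for a batch of seq2seq requests: encoder/decoder tensors plus per-request generation metadata"""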
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id, requests=self.requests, size=len(self)
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        decoder_input_lengths = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for i, r in enumerate(pb.requests):
            inputs.append(r.inputs)
            requests_idx_mapping[r.id] = i
            decoder_input_lengths.append(1)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        # Decoder sequence only contains the bos_token
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(len(pb.requests))
            .view(-1, 1)
        )
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=list(all_decoder_input_ids),
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
        )

    @tracer.start_as_current_span("filter")
    def filter(
        self, requests: List[generate_pb2.Request]
    ) -> Optional["Seq2SeqLMBatch"]:
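        """Keep only the given requests in the batch, indexing the cached tensors and metadata accordingly"""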
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        decoder_input_lengths = []
        offsets = []
        token_offsets = []

        all_decoder_input_ids = []

        next_token_choosers = []
        stopping_criterias = []

        max_input_length = 0
        max_decoder_input_length = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])

            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criterias.append(self.stopping_criterias[idx])

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        decoder_input_ids = self.decoder_input_ids[keep_indices]
        attention_mask = self.attention_mask[keep_indices]
        if self.decoder_attention_mask is not None:
            decoder_attention_mask = self.decoder_attention_mask[keep_indices]
        else:
            decoder_attention_mask = None

        encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]

        past_key_values = [
            [t[keep_indices] for t in layer] for layer in self.past_key_values
        ]

        return Seq2SeqLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=self.padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors"""

        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        offsets = []
        token_offsets = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for i, batch in enumerate(batches):
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

            # Iterate over attention layers
            for j, past in enumerate(batch.past_key_values):
                _, num_heads, _, head_dim = past[0].shape

                # This will run only once per layer
                if j == len(past_key_values):
                    past_key_values.append([])

                # Decoder past
                for k, t in enumerate(past[:2]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        (max_decoder_input_length - 1),
                        head_dim,
                    )

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    # We slice the past keys and values to remove the padding from previous batches
                    past_key_values[j][k][
                        start_index:end_index,
                        :,
                        -(batch.max_decoder_input_length - 1) :,
                        :,
                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]

                # encoder past
                for k, t in enumerate(past[2:]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        max_input_length,
                        head_dim,
                    )

                    idx = k + 2

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if idx == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    past_key_values[j][idx][
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]

            start_index += len(batch)

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
        )

    def __len__(self):
        return len(self.requests)


class Seq2SeqLM(Model):
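    """Generic encoder-decoder generation model built on a transformers AutoModelForSeq2SeqLM"""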
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
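        """Decode generated decoder token ids back into text, skipping special tokens"""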
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[List[torch.Tensor]],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
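        """Run a single model forward pass and return (logits, encoder_last_hidden_state, past_key_values)"""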
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
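        """Run one decoding step for the whole batch; return the generations and the updated batch, or None if all requests finished"""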
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # Finished requests
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), logits
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                all_decoder_input_ids, offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with decoder_input_length to remove padding
                # Decode all tokens
                output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                stopped = False

            # Prefill
            if stopping_criteria.current_tokens == 1:
                prefill_tokens = PrefillTokens(
                    [self.tokenizer.bos_token_id],
                    [float("nan")],
                    [self.tokenizer.bos_token],
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

            # Update values
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.max_input_length = max(batch.max_input_length, input_length)
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask as we added a new token to decoder_input_ids
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

        return generations, batch