# seq2seq_lm.py

import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
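    """Batch of seq2seq requests with padded encoder/decoder tensors and per-request generation state."""
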
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id, requests=self.requests, size=len(self)
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        decoder_input_lengths = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for i, r in enumerate(pb.requests):
            inputs.append(r.inputs)
            requests_idx_mapping[r.id] = i
            decoder_input_lengths.append(1)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        # Decoder sequence only contains the bos_token
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(len(pb.requests))
            .view(-1, 1)
        )
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=list(all_decoder_input_ids),
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
        )

    @tracer.start_as_current_span("filter")
    def filter(
        self, requests: List[generate_pb2.Request]
    ) -> Optional["Seq2SeqLMBatch"]:
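        """Keep only the given requests in the batch, slicing cached tensors (masks, encoder states, past key values) in place."""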
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        decoder_input_lengths = []
        offsets = []
        token_offsets = []

        all_decoder_input_ids = []

        next_token_choosers = []
        stopping_criterias = []

        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])

            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )
            padding_right_offset = max(
                padding_right_offset,
                self.stopping_criterias[idx].max_new_tokens
                - self.stopping_criterias[idx].current_tokens,
            )

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criterias.append(self.stopping_criterias[idx])

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
        if self.decoder_attention_mask is not None:
            self.decoder_attention_mask = self.decoder_attention_mask[
                keep_indices,
                -(self.padding_right_offset + max_decoder_input_length) : (
                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
                )
                + padding_right_offset,
            ]

        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
            keep_indices, -max_input_length:
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [
                [t for t in layer] for layer in self.past_key_values
            ]

        decoder_past_seq_len = max_decoder_input_length - 1
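        # The decoder past covers every decoder token except the newest one,
        # which has not been through a forward pass yet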
        for layer in self.past_key_values:
            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
            layer[2] = layer[2][keep_indices, :, -max_input_length:]
            layer[3] = layer[3][keep_indices, :, -max_input_length:]

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = None
        self.all_decoder_input_ids = all_decoder_input_ids
        self.input_lengths = input_lengths
        self.decoder_input_lengths = decoder_input_lengths
        self.offsets = offsets
        self.token_offsets = token_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.max_decoder_input_length = max_decoder_input_length
        self.padding_right_offset = padding_right_offset

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors"""

        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        offsets = []
        token_offsets = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for i, batch in enumerate(batches):
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
            batch.encoder_last_hidden_state = None

            # Ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t for t in layer] for layer in batch.past_key_values
                ]

            start_index = end_index

        # Determine shapes for new past kv tensors
        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][0].shape
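        # Each layer's past holds two decoder (self-attention) tensors followed by
        # two encoder (cross-attention) tensors: the decoder tensors grow with the
        # generated sequence while the encoder tensors stay at the padded input length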

        padded_dec_t_shape = (
            total_batch_size,
            num_heads,
            (max_decoder_input_length - 1),
            head_dim,
        )

        padded_enc_t_shape = (
            total_batch_size,
            num_heads,
            max_input_length,
            head_dim,
        )

        # Iterate over attention layers
        for j in range(len(first_past_kvs)):
            past_key_values.append([])

            # Decoder past
            for k in range(0, 2):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    past_seq_len = batch.max_decoder_input_length - 1
                    padded_past_values[
                        start_index:end_index, :, -past_seq_len:, :
                    ] = t[:, :, -past_seq_len:, :]
                    del t

                    start_index = end_index

            # Encoder past
            for k in range(2, 4):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    padded_past_values[
                        start_index:end_index, :, -batch.max_input_length:, :
                    ] = t[:, :, -batch.max_input_length:, :]
                    del t

                    start_index = end_index

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
        )

    def __len__(self):
        return len(self.requests)


class Seq2SeqLM(Model):
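    """Model served from a Hugging Face AutoModelForSeq2SeqLM checkpoint."""
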
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
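        # The decoder is seeded with the model's decoder_start_token_id; expose it
        # as the tokenizer bos token so that `from_pb` and the prefill logic can use it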
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
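        """Run the underlying model and return (logits, encoder_last_hidden_state, past_key_values)."""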
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
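        """Run one decoding step for the whole batch and return the generations together with the updated batch, or None if every request finished."""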
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # Finished requests
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), logits
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                all_decoder_input_ids, offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with decoder_input_length to remove padding
                # Decode all tokens
                output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                stopped = False

            # Prefill
            if stopping_criteria.current_tokens == 1:
                prefill_tokens = PrefillTokens(
                    [self.tokenizer.bos_token_id],
                    [float("nan")],
                    [self.tokenizer.bos_token],
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

            # Update values
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.max_input_length = max(batch.max_input_length, input_length)
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask as we added a new token to input_ids
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

        return generations, batch