import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
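    # Each layer entry is (decoder self-attention key, decoder self-attention value,
    # encoder cross-attention key, encoder cross-attention value); `filter` and
    # `concatenate` below rely on this layout when they slice indices 0/1 by decoder
    # length and indices 2/3 by encoder input length.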
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        decoder_input_lengths = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            inputs.append(r.inputs)
            requests_idx_mapping[r.id] = i
            decoder_input_lengths.append(1)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        # Decoder sequence only contains the bos_token
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(len(pb.requests))
            .view(-1, 1)
        )
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
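        # `split(1)` yields one single-element tensor per request; each one grows by
        # one id per step in `generate_token` via `torch.cat`.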

        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=list(all_decoder_input_ids),
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(
        self, requests: List[generate_pb2.Request]
    ) -> Optional["Seq2SeqLMBatch"]:
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        decoder_input_lengths = []
        offsets = []
        token_offsets = []

        all_decoder_input_ids = []

        next_token_choosers = []
        stopping_criterias = []

        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0

        total_remaining_decode_tokens = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])

            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
        if self.decoder_attention_mask is not None:
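            # Keep only the last `max_decoder_input_length` active positions and reuse
            # `padding_right_offset` columns of the old right padding for the decode
            # steps that are still remaining.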
            self.decoder_attention_mask = self.decoder_attention_mask[
                keep_indices,
                -(self.padding_right_offset + max_decoder_input_length) : (
                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
                )
                + padding_right_offset,
            ]

        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
            keep_indices, -max_input_length:
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [
                [t for t in layer] for layer in self.past_key_values
            ]

        decoder_past_seq_len = max_decoder_input_length - 1
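        # The decoder cache holds keys/values for every token except the one currently
        # being fed to the decoder, hence the `- 1`.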
        for layer in self.past_key_values:
            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
            layer[2] = layer[2][keep_indices, :, -max_input_length:]
            layer[3] = layer[3][keep_indices, :, -max_input_length:]

        max_tokens = (
            len(requests) * (max_input_length + max_decoder_input_length)
            + total_remaining_decode_tokens
        )

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = None
        self.all_decoder_input_ids = all_decoder_input_ids
        self.input_lengths = input_lengths
        self.decoder_input_lengths = decoder_input_lengths
        self.offsets = offsets
        self.token_offsets = token_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.max_decoder_input_length = max_decoder_input_length
        self.padding_right_offset = padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors"""

        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        offsets = []
        token_offsets = []
        next_token_choosers = []
        stopping_criterias = []
        max_tokens = 0

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for i, batch in enumerate(batches):
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
            batch.encoder_last_hidden_state = None

            # Ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t for t in layer] for layer in batch.past_key_values
                ]

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length
                - batch.max_input_length
                + max_decoder_input_length
                - batch.max_decoder_input_length
            ) * len(batch)

            start_index = end_index

        # Determine shapes for new past kv tensors
        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][0].shape

        padded_dec_t_shape = (
            total_batch_size,
            num_heads,
            (max_decoder_input_length - 1),
            head_dim,
        )

        padded_enc_t_shape = (
            total_batch_size,
            num_heads,
            max_input_length,
            head_dim,
        )

        # Iterate over attention layers
        for j in range(len(first_past_kvs)):
            past_key_values.append([])

            # Decoder past
            for k in range(0, 2):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    past_seq_len = batch.max_decoder_input_length - 1
                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
                        :, :, -past_seq_len:, :
                    ]
                    del t

                    start_index = end_index

            # Encoder past
            for k in range(2, 4):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    padded_past_values[
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]
                    del t

                    start_index = end_index

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class Seq2SeqLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        decode_buffer: int = 3,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize == "bitsandbytes",
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id
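        # Reuse the model's decoder_start_token_id as the "bos" token so that
        # `Seq2SeqLMBatch.from_pb` can seed `decoder_input_ids` with it.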

        super(Seq2SeqLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[List[torch.Tensor]],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # Finished requests
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), logits[-1:, :]
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                all_decoder_input_ids, offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Slice with decoder_input_length to remove padding
                    # Decode all tokens
                    output_text = self.decode(
                        all_decoder_input_ids[-decoder_input_length:]
                    )

                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
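                # The decoder prompt is only the bos token, so the single prefill
                # logprob is reported as NaN: there is no previous token to condition on.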
                if stopping_criteria.current_tokens == 1:
                    prefill_tokens = PrefillTokens(
                        [self.tokenizer.bos_token_id],
                        [float("nan")],
                        [self.tokenizer.bos_token],
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id_squeezed,
                    next_token_logprob,
                    next_token_text,
                    next_token_id_squeezed.item() in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            # Update values
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.max_input_length = max(batch.max_input_length, input_length)
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask as we added a new token to input_ids
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

        return generations, batch
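

# Usage sketch (illustrative only, not part of the server): the router normally drives
# this class over gRPC, but the per-step flow looks roughly like the following.
# `pb_batch` stands in for a `generate_pb2.Batch` built by the caller; the exact proto
# fields live in `text_generation_server.pb`, and `model.tokenizer` / `model.device`
# are assumed to be exposed by the `Model` base class.
#
#     model = Seq2SeqLM("t5-small")
#     batch = Seq2SeqLMBatch.from_pb(pb_batch, model.tokenizer, model.device)
#
#     # Prefill: runs the encoder once and emits the first decoder token.
#     generations, batch = model.generate_token(batch)
#
#     # Decode: keep generating until every request in the batch has stopped.
#     while batch is not None:
#         generations, batch = model.generate_token(batch)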