seq2seq_lm.py 26.1 KB
Newer Older
1
2
3
import torch

from dataclasses import dataclass
4
from opentelemetry import trace
5
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
6
from typing import Optional, Tuple, List, Type, Dict
7

8
9
10
11
12
13
14
15
16
from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
17

18
19
# OpenTelemetry tracer; its `start_as_current_span` decorates the batch and model methods in this module
tracer = trace.get_tracer(__name__)

20
21

@dataclass
class Seq2SeqLMBatch(Batch):
    """Batch state for encoder-decoder (seq2seq) generation.

    Holds the tokenized encoder inputs and the incremental decoder state:
    decoder ids, attention masks, encoder hidden states and past key/values.
    Also holds the per-request bookkeeping (lengths, detokenization offsets,
    token choosers, stopping criteria) consumed by `Seq2SeqLM.generate_token`.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # Maps a request id to its row index in the batch tensors and lists
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    # Only needed for the prefill forward; set to None once it has run
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    # None until this batch is produced by `concatenate`
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values.
    # Each layer holds [decoder keys, decoder values, encoder keys, encoder values].
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    # Detokenization offsets passed to `decode_token`
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    # Number of trailing (right) padding columns reserved in the decoder attention mask
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch.

        Args:
            pb: batch protobuf holding the requests to encode.
            tokenizer: tokenizer used to encode the request inputs.
            device: device the batch tensors are moved to.

        Returns:
            A new batch with no decoder history: every decoder sequence only
            contains `tokenizer.bos_token_id`.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        decoder_input_lengths = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            inputs.append(r.inputs)
            requests_idx_mapping[r.id] = i
            decoder_input_lengths.append(1)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            # Right padding must fit the longest requested generation
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        # Convert to a Python int right away so that `max_tokens` below is an int,
        # as declared on the dataclass, rather than a 0-d device tensor
        max_input_length = int(input_lengths.max())

        # Decoder sequence only contains the bos_token
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(len(pb.requests))
            .view(-1, 1)
        )
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=list(all_decoder_input_ids),
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(
        self, requests: List[generate_pb2.Request]
    ) -> Optional["Seq2SeqLMBatch"]:
        """Keep only `requests` in this batch, slicing every tensor accordingly.

        This method subscripts `encoder_last_hidden_state` and `past_key_values`
        unconditionally. It therefore assumes at least one `generate_token` step
        has already populated them.

        Args:
            requests: requests to keep; each id must be in `requests_idx_mapping`.

        Returns:
            This batch, mutated in place to only contain `requests`.

        Raises:
            ValueError: if `requests` is empty.
        """
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        decoder_input_lengths = []
        offsets = []
        token_offsets = []

        all_decoder_input_ids = []

        next_token_choosers = []
        stopping_criterias = []

        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0

        remaining_decode_tokens = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])

            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )

            next_token_choosers.append(self.next_token_choosers[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            remaining_decode_tokens += remaining_tokens
            # Right padding only needs to fit the longest remaining generation
            padding_right_offset = max(padding_right_offset, remaining_tokens)

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
        if self.decoder_attention_mask is not None:
            # Keep the last `max_decoder_input_length` real columns, then resize the
            # right padding from the old `self.padding_right_offset` to the new one
            mask_width = self.decoder_attention_mask.shape[1]
            real_end = mask_width - self.padding_right_offset
            self.decoder_attention_mask = self.decoder_attention_mask[
                keep_indices,
                max(0, real_end - max_decoder_input_length) : real_end
                + padding_right_offset,
            ]

        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
            keep_indices, -max_input_length:
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if isinstance(self.past_key_values[0], tuple):
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Decoder past holds one entry fewer than the decoder length (the current
        # decoder input token is not part of the cache yet)
        decoder_past_seq_len = max_decoder_input_length - 1
        for layer in self.past_key_values:
            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
            layer[2] = layer[2][keep_indices, :, -max_input_length:]
            layer[3] = layer[3][keep_indices, :, -max_input_length:]

        max_tokens = (
            len(requests) * (max_input_length + max_decoder_input_length)
            + remaining_decode_tokens
        )

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = None
        self.all_decoder_input_ids = all_decoder_input_ids
        self.input_lengths = input_lengths
        self.decoder_input_lengths = decoder_input_lengths
        self.offsets = offsets
        self.token_offsets = token_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.max_decoder_input_length = max_decoder_input_length
        self.padding_right_offset = padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors.

        Padding is applied as follows:
        - Encoder-side tensors are left-padded to the largest `max_input_length`.
        - Decoder-side tensors are left-padded to the largest
          `max_decoder_input_length`.
        - The decoder attention mask is right-padded to the largest
          `padding_right_offset`.

        The encoder hidden states and past key/values of the input batches are
        released (set to None) as they are copied.

        Raises:
            ValueError: if a batch has no `encoder_last_hidden_state`, i.e. it
                has not run a forward step yet.
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        offsets = []
        token_offsets = []
        next_token_choosers = []
        stopping_criterias = []
        max_tokens = 0

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for batch in batches:
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Offset each batch's mapping by the cumulative batch size. A fresh dict
            # is built so that the mapping of `batches[0]` is not mutated in place.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )

            # Real decoder tokens occupy [left_offset, max_decoder_input_length) in the
            # new mask. The explicit end index (instead of `-padding_right_offset`)
            # stays correct when the right padding is 0.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                # If the decoder mask does not exist yet, all generations started at the same time and we never
                # concatenated this batch. All generations are of length `batch.max_decoder_input_length`.
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:max_decoder_input_length,
                ] = 1
            else:
                # If it exists, we need to index the real columns of the batch mask
                batch_real_end = (
                    batch.decoder_attention_mask.shape[1] - batch.padding_right_offset
                )
                batch_left_offset = batch_real_end - batch.max_decoder_input_length
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:max_decoder_input_length,
                ] = batch.decoder_attention_mask[:, batch_left_offset:batch_real_end]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
            # Release the original tensor as soon as it has been copied
            batch.encoder_last_hidden_state = None

            # Ensure that we can update tensors in-place
            if isinstance(batch.past_key_values[0], tuple):
                batch.past_key_values = [list(layer) for layer in batch.past_key_values]

            start_index = end_index

            # Account for the padding tokens added to this batch's rows while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length
                - batch.max_input_length
                + max_decoder_input_length
                - batch.max_decoder_input_length
            ) * len(batch)

        # Determine shapes for new past kv tensors
        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][0].shape

        padded_dec_t_shape = (
            total_batch_size,
            num_heads,
            (max_decoder_input_length - 1),
            head_dim,
        )

        padded_enc_t_shape = (
            total_batch_size,
            num_heads,
            max_input_length,
            head_dim,
        )

        # Iterate over attention layers
        for j in range(len(first_past_kvs)):
            past_key_values.append([])

            # Decoder past
            for k in range(0, 2):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    past_seq_len = batch.max_decoder_input_length - 1
                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
                        :, :, -past_seq_len:, :
                    ]
                    del t

                    start_index = end_index

            # Encoder past
            for k in range(2, 4):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    padded_past_values[
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]
                    del t

                    start_index = end_index

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)

500
501

class Seq2SeqLM(Model):
    """Encoder-decoder text generation model that decodes `Seq2SeqLMBatch` batches."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        """Load the seq2seq model and its tokenizer.

        Args:
            model_id: model identifier passed to `from_pretrained`.
            revision: optional model revision passed to `from_pretrained`.
            quantize: load the weights with `load_in_8bit`; only allowed with CUDA.
            decode_buffer: forwarded to the `Model` base class.

        Raises:
            ValueError: if `quantize` is requested while CUDA is unavailable.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            # Prefer bfloat16 when the GPU supports it
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        # Left padding right-aligns the real tokens, which the `[-length:]`
        # slicing done on batch tensors relies on
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        # Batches start every decoder sequence with `tokenizer.bos_token_id`, so
        # point it at the model's decoder start token
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super().__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        """Batch class consumed by `generate_token`."""
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        """Decode token ids to text, skipping special tokens."""
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run one model forward pass with the KV cache enabled.

        `encoder_last_hidden_state` is forwarded as `encoder_outputs`. When it is
        given, the encoder results are reused instead of taken from `input_ids`.

        Returns:
            The logits, the encoder last hidden state and the updated past key/values.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        """Run one decoding step for every request in `batch`.

        The batch is updated in place: new decoder ids, lengths, offsets and
        model caches.

        Returns:
            One `Generation` per request, and the updated batch. The batch is
            None instead when every request in it stopped during this step.
        """
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape (drop the unused right padding columns);
            # a live batch keeps padding_right_offset >= 1
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # Finished requests
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.offsets,
            batch.token_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            offset,
            token_offset,
            decoder_input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), request_logits
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                all_decoder_input_ids, offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # `all_decoder_input_ids` now holds `decoder_input_length + 1` tokens;
                # keeping the last `decoder_input_length` drops the leading decoder
                # start token before decoding
                output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                stopped = False

            # Prefill: presumably current_tokens == 1 marks this request's first
            # generated token (assumes the stopping criteria counts calls)
            if stopping_criteria.current_tokens == 1:
                prefill_tokens = PrefillTokens(
                    [self.tokenizer.bos_token_id],
                    [float("nan")],
                    [self.tokenizer.bos_token],
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

            # Update values
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask as we added a new token to input_ids:
        # the first right-padding column becomes a real token column
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

        return generations, batch