import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]
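    # Layout note (an assumption based on standard Hugging Face encoder-decoder caches, e.g. T5/BART):
    # each layer entry is expected to hold four tensors,
    # (self-attn key, self-attn value, cross-attn key, cross-attn value),
    # which is why `filter` and `concatenate` below treat indices 0-1 (decoder past)
    # and 2-3 (encoder past) differently.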

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            inputs.append(r.inputs)
            requests_idx_mapping[r.id] = i
            decoder_input_lengths.append(1)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        # Decoder sequence only contains the bos_token
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(len(pb.requests))
            .view(-1, 1)
        )
        for _ in pb.requests:
            prefix_offsets.append(0)
            read_offsets.append(1)
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

        max_tokens = len(inputs) * max_input_length + max_decode_tokens
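        # Padding-based upper bound: every request is counted at `max_input_length` encoder
        # positions plus its own `max_new_tokens` decoder positions, matching the
        # "maximum number of tokens this batch will grow to" field above.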

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=list(all_decoder_input_ids),
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []

        all_decoder_input_ids = []

        next_token_choosers = []
        stopping_criterias = []

        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0

        total_remaining_decode_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
        if self.decoder_attention_mask is not None:
            self.decoder_attention_mask = self.decoder_attention_mask[
                keep_indices,
                -(self.padding_right_offset + max_decoder_input_length) : (
                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
                )
                + padding_right_offset,
            ]
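            # The slice keeps the last `max_decoder_input_length` used positions plus the new,
            # possibly smaller, `padding_right_offset` empty positions on the right.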

        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
            keep_indices, -max_input_length:
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [
                [t for t in layer] for layer in self.past_key_values
            ]

        decoder_past_seq_len = max_decoder_input_length - 1
        for layer in self.past_key_values:
            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
            layer[2] = layer[2][keep_indices, :, -max_input_length:]
            layer[3] = layer[3][keep_indices, :, -max_input_length:]

        max_tokens = (
            len(request_ids) * (max_input_length + max_decoder_input_length)
            + total_remaining_decode_tokens
        )

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = None
        self.all_decoder_input_ids = all_decoder_input_ids
        self.input_lengths = input_lengths
        self.decoder_input_lengths = decoder_input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.max_decoder_input_length = max_decoder_input_length
        self.padding_right_offset = padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors"""
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        next_token_choosers = []
        stopping_criterias = []
        max_tokens = 0

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for i, batch in enumerate(batches):
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
            batch.encoder_last_hidden_state = None

            # Ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t for t in layer] for layer in batch.past_key_values
                ]

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length
                - batch.max_input_length
                + max_decoder_input_length
                - batch.max_decoder_input_length
            ) * len(batch)
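            # Every sequence of this batch is now padded up to the new global maximum encoder and
            # decoder lengths, so its token budget grows by the corresponding difference.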
            start_index = end_index

        # Determine shapes for new past kv tensors
        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][0].shape

        padded_dec_t_shape = (
            total_batch_size,
            num_heads,
            (max_decoder_input_length - 1),
            head_dim,
        )

        padded_enc_t_shape = (
            total_batch_size,
            num_heads,
            max_input_length,
            head_dim,
        )

        # Iterate over attention layers
        for j in range(len(first_past_kvs)):
            past_key_values.append([])

            # Decoder past
            for k in range(0, 2):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    past_seq_len = batch.max_decoder_input_length - 1
                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
                        :, :, -past_seq_len:, :
                    ]
                    del t

                    start_index = end_index

            # Encoder past
            for k in range(2, 4):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    padded_past_values[
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]
                    del t

                    start_index = end_index

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class Seq2SeqLM(Model):
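    # Rough lifecycle (an assumption based on how the sharded server drives these classes):
    # the router sends a Batch protobuf, `Seq2SeqLMBatch.from_pb` tokenizes it, `generate_token`
    # decodes one token per request per call, and `filter` / `concatenate` shrink or merge the
    # cached batches as requests finish or new ones arrive.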
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto"
            if torch.cuda.is_available() and torch.cuda.device_count() > 1
            else None,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()
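        # Device placement sketch: with several GPUs, accelerate's `device_map="auto"` shards the
        # model across them at load time; with a single GPU the model is moved explicitly with
        # `.cuda()`; on CPU it stays in float32 and quantization is rejected above.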

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = model.config.decoder_start_token_id
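        # `Seq2SeqLMBatch.from_pb` seeds the decoder with `tokenizer.bos_token_id`, so it is pointed
        # at the model's `decoder_start_token_id` (for T5-style models this is typically the pad token).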

        super(Seq2SeqLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
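        # Called from `generate_token` on `all_decoder_input_ids[-decoder_input_length:]` once a
        # request stops; special tokens (pad/eos) are skipped so only the generated text is returned.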

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )
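        # For typical Hugging Face encoder-decoder models, `logits` has shape
        # [batch_size, decoder_sequence_length, vocab_size] (the decoder length is 1 once a cache
        # is passed) and `encoder_last_hidden_state` has shape
        # [batch_size, encoder_sequence_length, hidden_size].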

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # Finished requests
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), logits[-1:, :]
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_decoder_input_ids, prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Slice with decoder_input_length to remove padding
                    # Decode all tokens
                    output_text = self.decode(
                        all_decoder_input_ids[-decoder_input_length:]
                    )

                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1:
                    prefill_tokens = PrefillTokens(
                        [self.tokenizer.bos_token_id],
                        [float("nan")],
                        [self.tokenizer.bos_token],
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id_squeezed,
                    next_token_logprob,
                    next_token_text,
                    next_token_id_squeezed.item() in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            # Update values
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, input_length)
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask since we added a new token to decoder_input_ids
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1
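        # One reserved right-padding slot is now occupied by the freshly generated token, so the
        # remaining right padding shrinks by one.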

        return generations, batch