import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            inputs.append(r.inputs)
            requests_idx_mapping[r.id] = i
            decoder_input_lengths.append(1)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        # Decoder sequence only contains the bos_token
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(len(pb.requests))
            .view(-1, 1)
        )
        for _ in pb.requests:
            prefix_offsets.append(0)
            read_offsets.append(1)
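        # One 1-D tensor of decoder token ids per request; each grows as tokens are generated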
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

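        # Worst-case token count for this batch: every input padded to the longest
        # prompt, plus the sum of requested max_new_tokens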
        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=list(all_decoder_input_ids),
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(
        self, requests: List[generate_pb2.Request]
    ) -> Optional["Seq2SeqLMBatch"]:
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []

        all_decoder_input_ids = []

        next_token_choosers = []
        stopping_criterias = []

        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0

        total_remaining_decode_tokens = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
        if self.decoder_attention_mask is not None:
            self.decoder_attention_mask = self.decoder_attention_mask[
                keep_indices,
                -(self.padding_right_offset + max_decoder_input_length) : (
                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
                )
                + padding_right_offset,
            ]

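        # Inputs are left-padded, so the valid encoder positions are the last max_input_length ones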
        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
            keep_indices, -max_input_length:
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [
                [t for t in layer] for layer in self.past_key_values
            ]

        decoder_past_seq_len = max_decoder_input_length - 1
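        # Each layer's past is (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value):
        # indices 0/1 track the decoder length, indices 2/3 track the encoder input length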
        for layer in self.past_key_values:
            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
            layer[2] = layer[2][keep_indices, :, -max_input_length:]
            layer[3] = layer[3][keep_indices, :, -max_input_length:]

        max_tokens = (
            len(requests) * (max_input_length + max_decoder_input_length)
            + total_remaining_decode_tokens
        )

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = None
        self.all_decoder_input_ids = all_decoder_input_ids
        self.input_lengths = input_lengths
        self.decoder_input_lengths = decoder_input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.max_decoder_input_length = max_decoder_input_length
        self.padding_right_offset = padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors"""

        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        next_token_choosers = []
        stopping_criterias = []
        max_tokens = 0

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for i, batch in enumerate(batches):
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
            batch.encoder_last_hidden_state = None

            # Ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t for t in layer] for layer in batch.past_key_values
                ]

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length
                - batch.max_input_length
                + max_decoder_input_length
                - batch.max_decoder_input_length
            ) * len(batch)

            start_index = end_index

        # Determine shapes for new past kv tensors
        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][0].shape

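        # The decoder cache lags the decoder ids by one token: the latest token has not
        # been run through the model yet, hence max_decoder_input_length - 1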
        padded_dec_t_shape = (
            total_batch_size,
            num_heads,
            (max_decoder_input_length - 1),
            head_dim,
        )

        padded_enc_t_shape = (
            total_batch_size,
            num_heads,
            max_input_length,
            head_dim,
        )

        # Iterate over attention layers
        for j in range(len(first_past_kvs)):
            past_key_values.append([])

            # Decoder past
            for k in range(0, 2):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    past_seq_len = batch.max_decoder_input_length - 1
                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
                        :, :, -past_seq_len:, :
                    ]
                    del t

                    start_index = end_index

            # Encoder past
            for k in range(2, 4):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    padded_past_values[
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]
                    del t

                    start_index = end_index

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class Seq2SeqLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
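            # Shard across GPUs with accelerate only when more than one is available;
            # the single-GPU case is moved to the device explicitly below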
            device_map="auto"
            if torch.cuda.is_available() and torch.cuda.device_count() > 1
            else None,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
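        # Seq2seq decoding starts from decoder_start_token_id; expose it as bos_token_id
        # so that Seq2SeqLMBatch.from_pb can seed decoder_input_ids with it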
        tokenizer.bos_token_id = model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
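            # `encoder_last_hidden_state` is either None (prefill) or a one-element list
            # built in `generate_token`, since transformers indexes `encoder_outputs[0]` internally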
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
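        # The decoder attention mask is allocated with room for future tokens
        # (padding_right_offset); slice it down to the current decoder length before the forward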
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # Finished requests
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), logits[-1:, :]
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_decoder_input_ids, prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Slice with decoder_input_length to remove padding
                    # Decode all tokens
                    output_text = self.decode(
                        all_decoder_input_ids[-decoder_input_length:]
                    )

                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
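                # For seq2seq models the decoder "prompt" is only the bos token,
                # so it is reported with a single NaN logprob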
                if stopping_criteria.current_tokens == 1:
                    prefill_tokens = PrefillTokens(
                        [self.tokenizer.bos_token_id],
                        [float("nan")],
                        [self.tokenizer.bos_token],
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id_squeezed,
                    next_token_logprob,
                    next_token_text,
                    next_token_id_squeezed.item() in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            # Update values
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, input_length)
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask as we added a new token to input_ids
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

        return generations, batch