# seq2seq_lm.py
from text_generation_server.utils.tokens import batch_top_tokens
import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    PrefillTokens,
    TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
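    """Padded batch state for an encoder-decoder model, mutated in place
    between decoding steps."""
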
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            inputs.append(r.inputs)
            requests_idx_mapping[r.id] = i
            decoder_input_lengths.append(1)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        # Decoder sequence only contains the bos_token
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(len(pb.requests))
            .view(-1, 1)
        )
        for _ in pb.requests:
            prefix_offsets.append(0)
            read_offsets.append(1)
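        # One single-token tensor per request, kept as a list so each sequence
        # can grow independently as tokens are generated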
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

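        # Upper bound on the number of tokens this batch can occupy once every
        # request has used its full max_new_tokens budget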
        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=list(all_decoder_input_ids),
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length.item(),
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
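        """Keep only the given request ids, slicing every cached tensor down to them."""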
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []

        all_decoder_input_ids = []

        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []

        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0

        total_remaining_decode_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(self.top_n_tokens[idx])
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
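        # Inputs are left-padded, so the last max_input_length columns hold every real token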
        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
        if self.decoder_attention_mask is not None:
            self.decoder_attention_mask = self.decoder_attention_mask[
                keep_indices,
                -(self.padding_right_offset + max_decoder_input_length) : (
                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
                )
                + padding_right_offset,
            ]

        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
            keep_indices, -max_input_length:
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [
                [t for t in layer] for layer in self.past_key_values
            ]

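        # Each layer caches (decoder key, decoder value, encoder key, encoder value).
        # The decoder past is one step shorter than the sequence: the latest token
        # has not been through the model yet.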
        decoder_past_seq_len = max_decoder_input_length - 1
        for layer in self.past_key_values:
            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
            layer[2] = layer[2][keep_indices, :, -max_input_length:]
            layer[3] = layer[3][keep_indices, :, -max_input_length:]

        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
        max_tokens = (
            len(request_ids) * (max_input_length + max_decoder_input_length)
            + total_remaining_decode_tokens
        )

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = None
        self.all_decoder_input_ids = all_decoder_input_ids
        self.input_lengths = input_lengths
        self.decoder_input_lengths = decoder_input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.top_n_tokens = top_n_tokens
        self.top_n_tokens_tensor = top_n_tokens_tensor
        self.max_input_length = max_input_length
        self.max_decoder_input_length = max_decoder_input_length
        self.padding_right_offset = padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors"""

        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        max_tokens = 0

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        top_n_tokens_tensor = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for i, batch in enumerate(batches):
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
            top_n_tokens.extend(batch.top_n_tokens)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            if top_n_tokens_tensor is None:
                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                    total_batch_size,
                )
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
            batch.encoder_last_hidden_state = None

            # Ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t for t in layer] for layer in batch.past_key_values
                ]

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length
                - batch.max_input_length
                + max_decoder_input_length
                - batch.max_decoder_input_length
            ) * len(batch)

            start_index = end_index

        # Determine shapes for new past kv tensors
        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][0].shape

        padded_dec_t_shape = (
            total_batch_size,
            num_heads,
            (max_decoder_input_length - 1),
            head_dim,
        )

        padded_enc_t_shape = (
            total_batch_size,
            num_heads,
            max_input_length,
            head_dim,
        )

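        # Layer caches are ordered (decoder key, decoder value, encoder key, encoder value):
        # indices 0-1 are re-padded to the decoder shape, 2-3 to the encoder shape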
        # Iterate over attention layers
        for j in range(len(first_past_kvs)):
            past_key_values.append([])

            # Decoder past
            for k in range(0, 2):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    past_seq_len = batch.max_decoder_input_length - 1
                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
                        :, :, -past_seq_len:, :
                    ]
                    del t

                    start_index = end_index

            # Encoder past
            for k in range(2, 4):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    padded_past_values[
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]
                    del t

                    start_index = end_index

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class Seq2SeqLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto"
            if torch.cuda.is_available() and torch.cuda.device_count() > 1
            else None,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
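        # Seq2seq decoders start generating from decoder_start_token_id; expose it
        # as bos_token_id so the shared decoding helpers can use it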
        tokenizer.bos_token_id = model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[List[torch.Tensor]],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

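        # Requested top tokens come from the log-probabilities of the newest decoder position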
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            torch.log_softmax(logits[:, -1], -1),
        )

        # Finished requests
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
            batch.top_n_tokens,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
            top_n_tokens,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), logits[-1:, :]
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_decoder_input_ids, prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Slice with decoder_input_length to remove padding
                    # Decode all tokens
                    output_text = self.decode(
                        all_decoder_input_ids[-decoder_input_length:]
                    )

                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    prefill_tokens = PrefillTokens(
                        [self.tokenizer.bos_token_id],
                        [float("nan")],
                        [self.tokenizer.bos_token],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    toptoken_texts = self.tokenizer.batch_decode(
                        top_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    special_toptokens = [
                        token_id in self.all_special_ids for token_id in top_token_ids
                    ]
                    top_tokens = TopTokens(
                        top_token_ids,
                        top_token_logprobs,
                        toptoken_texts,
                        special_toptokens,
                    )
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id_squeezed,
                    next_token_logprob,
                    next_token_text,
                    next_token_id_squeezed.item() in self.all_special_ids,
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, input_length)
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask as we added a new token to decoder_input_ids
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

        return generations, batch
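
# A minimal usage sketch of the prefill/decode loop above. It is illustrative
# only: "t5-small" and `pb_batch` (a `generate_pb2.Batch` normally built by the
# router) are assumptions, not part of this module.
#
#     model = Seq2SeqLM("t5-small")
#     batch = Seq2SeqLMBatch.from_pb(pb_batch, model.tokenizer, model.dtype, model.device)
#     generations, batch = model.generate_token(batch)  # prefill
#     while batch is not None:
#         generations, batch = model.generate_token(batch)  # one decode step per call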