"""Encoder-decoder (seq2seq) model support for the text generation server."""

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch
from opentelemetry import trace
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedTokenizerBase

from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    Tokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.utils.tokens import batch_top_tokens

# Module-level OpenTelemetry tracer used to create spans for the batch/model methods below
tracer = trace.get_tracer(__name__)

@dataclass
class Seq2SeqLMBatch(Batch):
    """Per-batch state for encoder-decoder (seq2seq) generation.

    Encoder-side tensors are left-padded to ``max_input_length`` (they are always
    sliced with ``-max_input_length:``). Decoder-side tensors are left-padded to
    ``max_decoder_input_length``; once ``decoder_attention_mask`` exists it also
    reserves ``padding_right_offset`` columns on its right for tokens that have
    not been generated yet.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # Maps a request id to its row index in every per-request list/tensor below
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    # Only set before the first forward; filter/concatenate/generate_token reset it to None
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    # None until this batch has been built by `concatenate`
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    # One 1-D tensor per request holding every decoder token so far
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values.
    # Per layer, entries 0-1 are decoder past and entries 2-3 are encoder past.
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    # Incremental detokenization offsets passed to / returned by `Model.decode_token`
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    # Number of not-yet-filled columns reserved on the right of decoder_attention_mask
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch.

        Args:
            pb: incoming batch of requests.
            tokenizer: tokenizer used to encode all request inputs at once; its
                ``bos_token_id`` is the single initial token of every decoder sequence.
            dtype: accepted but not used by this implementation.
            device: device on which all batch tensors are created.

        Returns:
            A batch that has not run any forward pass yet (no decoder mask, no
            encoder states, no past key values).
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            inputs.append(r.inputs)
            requests_idx_mapping[r.id] = i
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            # Reserve enough decoder columns for the longest possible generation
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        # Convert to a Python int so max_tokens below is an int, not a 0-d tensor
        max_input_length = input_lengths.max().item()

        # Decoder sequence only contains the bos_token
        num_requests = len(pb.requests)
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(num_requests)
            .view(-1, 1)
        )
        decoder_input_lengths = [1] * num_requests
        prefix_offsets = [0] * num_requests
        read_offsets = [1] * num_requests
        all_decoder_input_ids = list(decoder_input_ids.view(-1).split(1))

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        # Reported through to_pb as CachedBatch.max_tokens
        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
        """Keep only the requests in ``request_ids``, trimming padding in place.

        The batch is mutated and returned. ``encoder_last_hidden_state`` and
        ``past_key_values`` are indexed unconditionally, so this assumes the batch
        has already run at least one forward pass.

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_decoder_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []

        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0

        total_remaining_decode_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
        if self.decoder_attention_mask is not None:
            # Keep the rightmost `max_decoder_input_length` filled columns followed
            # by the new `padding_right_offset` reserved columns
            self.decoder_attention_mask = self.decoder_attention_mask[
                keep_indices,
                -(self.padding_right_offset + max_decoder_input_length) : (
                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
                )
                + padding_right_offset,
            ]

        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
            keep_indices, -max_input_length:
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if isinstance(self.past_key_values[0], tuple):
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        decoder_past_seq_len = max_decoder_input_length - 1
        for layer in self.past_key_values:
            # Decoder past (0, 1) is trimmed to the decoder length, encoder past (2, 3)
            # to the encoder length
            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
            layer[2] = layer[2][keep_indices, :, -max_input_length:]
            layer[3] = layer[3][keep_indices, :, -max_input_length:]

        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
        # Use the remaining decode budget of *every* kept request; the loop variable
        # `remaining_decode_tokens` only holds the last request's value
        max_tokens = (
            len(request_ids) * (max_input_length + max_decoder_input_length)
            + total_remaining_decode_tokens
        )

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        # input_ids are only needed for the first forward
        self.input_ids = None
        self.all_decoder_input_ids = all_decoder_input_ids
        self.input_lengths = input_lengths
        self.decoder_input_lengths = decoder_input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.top_n_tokens = top_n_tokens
        self.top_n_tokens_tensor = top_n_tokens_tensor
        self.max_input_length = max_input_length
        self.max_decoder_input_length = max_decoder_input_length
        self.padding_right_offset = padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors.

        Every batch must already have run at least one forward pass. The input
        batches' ``encoder_last_hidden_state`` and ``past_key_values`` entries are
        set to None while they are copied, so the inputs must not be reused.

        Raises:
            ValueError: if any batch has no ``encoder_last_hidden_state``. This is
                checked before any batch is modified.
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            # We only concatenate batches that did at least one step. Validate
            # before touching anything so no batch is left half-consumed.
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        max_tokens = 0

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        top_n_tokens_tensor = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
            top_n_tokens.extend(batch.top_n_tokens)

            # Offset each batch's mapping by the cumulative batch size (0 for the
            # first batch); a fresh dict avoids mutating the input batch's mapping
            for request_id, idx in batch.requests_idx_mapping.items():
                requests_idx_mapping[request_id] = idx + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices (right-aligned, i.e. left padding)
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            if top_n_tokens_tensor is None:
                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                    total_batch_size,
                )
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
            # Drop the reference so the source tensor can be freed
            batch.encoder_last_hidden_state = None

            # Ensure that we can update tensors in-place
            if isinstance(batch.past_key_values[0], tuple):
                batch.past_key_values = [list(layer) for layer in batch.past_key_values]

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length
                - batch.max_input_length
                + max_decoder_input_length
                - batch.max_decoder_input_length
            ) * len(batch)

            start_index = end_index

        # Determine shapes for new past kv tensors
        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][0].shape

        padded_dec_t_shape = (
            total_batch_size,
            num_heads,
            (max_decoder_input_length - 1),
            head_dim,
        )

        padded_enc_t_shape = (
            total_batch_size,
            num_heads,
            max_input_length,
            head_dim,
        )

        # Iterate over attention layers
        for j in range(len(first_past_kvs)):
            layer_past = []
            # Entries 0-1: decoder past; entries 2-3: encoder past
            for k in range(4):
                is_decoder_past = k < 2
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(
                    padded_dec_t_shape if is_decoder_past else padded_enc_t_shape
                )
                layer_past.append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    if is_decoder_past:
                        past_seq_len = batch.max_decoder_input_length - 1
                    else:
                        past_seq_len = batch.max_input_length
                    padded_past_values[
                        start_index:end_index, :, -past_seq_len:, :
                    ] = t[:, :, -past_seq_len:, :]
                    del t

                    start_index = end_index
            past_key_values.append(layer_past)

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class Seq2SeqLM(Model):
    """Generic encoder-decoder model loaded through ``AutoModelForSeq2SeqLM``."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load the model and its (left-padding, left-truncating) tokenizer.

        On CUDA the default dtype is float16, and the model is spread with
        ``device_map="auto"`` when more than one GPU is visible. On CPU the default
        dtype is float32.

        Raises:
            ValueError: if ``quantize`` is set while CUDA is unavailable.
        """
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        num_gpus = torch.cuda.device_count() if use_cuda else 0

        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if num_gpus > 1 else None,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if num_gpus == 1:
            model = model.cuda()

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        # Batches seed every decoder sequence with `tokenizer.bos_token_id`, so point
        # it at the model's decoder start token
        tokenizer.bos_token_id = model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        """Batch class used with this model."""
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        """Decode token ids to text, skipping special tokens."""
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[List[torch.Tensor]],
        past_key_values: Optional[List] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run one model forward with the KV cache enabled.

        ``encoder_last_hidden_state`` is forwarded as ``encoder_outputs``. When it is
        given, the encoder output from a previous step is reused.

        Returns:
            ``(logits, encoder_last_hidden_state, past_key_values)`` from the model
            outputs.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        """Run one decoding step for every request in ``batch``.

        The batch is updated in place.

        Returns:
            ``(generations, batch)``. ``generations`` only holds entries for requests
            assigned to this shard (``i % world_size == rank``). The batch is
            ``None`` once every request has stopped.
        """
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape (drop the not-yet-filled right columns)
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            torch.log_softmax(logits[:, -1], -1),
        )

        # Finished requests
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
            batch.top_n_tokens,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            prefix_offset,
            read_offset,
            decoder_input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
            top_n_tokens,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), request_logits[-1:, :]
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_decoder_input_ids, prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Slice with decoder_input_length to remove padding
                    # Decode all tokens
                    output_text, _, _ = self.decode_token(
                        all_decoder_input_ids,
                        prefix_offset=len(all_decoder_input_ids)
                        - decoder_input_length
                        - 1,
                        read_offset=len(all_decoder_input_ids) - decoder_input_length,
                        skip_special_tokens=True,
                    )

                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill: the only "prompt" token on the decoder side is the start token
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    prefill_tokens = Tokens(
                        [self.tokenizer.bos_token_id],
                        [float("nan")],
                        [self.tokenizer.bos_token],
                        [False],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    toptoken_texts = self.tokenizer.batch_decode(
                        top_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    special_toptokens = [
                        token_id in self.all_special_ids for token_id in top_token_ids
                    ]
                    top_tokens = Tokens(
                        top_token_ids,
                        top_token_logprobs,
                        toptoken_texts,
                        special_toptokens,
                    )
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        [next_token_id_squeezed],
                        [next_token_logprob],
                        [next_token_text],
                        [next_token_id_squeezed.item() in self.all_special_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values (input lengths are unchanged by a decoding step)
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask as we added a new token to input_ids:
        # mark the first reserved right-hand column as attended, then shrink the reservation
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

        return generations, batch