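"""Seq2Seq (encoder-decoder) language model support for text-generation-server.

Defines ``Seq2SeqLMBatch``, a padded batch that carries encoder and decoder
state between generation steps, and ``Seq2SeqLM``, the model wrapper that runs
prefill and decode for ``AutoModelForSeq2SeqLM``-style checkpoints.
"""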
import torch
import torch.distributed
import time

from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    PreTrainedTokenizerBase,
    AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    Tokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            inputs.append(concat_text_chunks(r.input_chunks.chunks))
            requests_idx_mapping[r.id] = i
            decoder_input_lengths.append(1)
            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        # Decoder sequence only contains the bos_token
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(len(pb.requests))
            .view(-1, 1)
        )
        for _ in pb.requests:
            prefix_offsets.append(0)
            read_offsets.append(1)
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=list(all_decoder_input_ids),
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length.item(),
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []

        all_decoder_input_ids = []

        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []

        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0

        total_remaining_decode_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(self.top_n_tokens[idx])
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
        if self.decoder_attention_mask is not None:
            self.decoder_attention_mask = self.decoder_attention_mask[
                keep_indices,
                -(self.padding_right_offset + max_decoder_input_length) : (
                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
                )
                + padding_right_offset,
            ]

        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
            keep_indices, -max_input_length:
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [
                [t for t in layer] for layer in self.past_key_values
            ]
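        # Each past_key_values layer holds [dec_key, dec_value, enc_key, enc_value]:
        # indices 0/1 are sliced to the decoder length, 2/3 to the encoder input length.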

        decoder_past_seq_len = max_decoder_input_length - 1
        for layer in self.past_key_values:
            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
            layer[2] = layer[2][keep_indices, :, -max_input_length:]
            layer[3] = layer[3][keep_indices, :, -max_input_length:]

        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
        max_tokens = (
            len(request_ids) * (max_input_length + max_decoder_input_length)
            + total_remaining_decode_tokens
        )

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = None
        self.all_decoder_input_ids = all_decoder_input_ids
        self.input_lengths = input_lengths
        self.decoder_input_lengths = decoder_input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.top_n_tokens = top_n_tokens
        self.top_n_tokens_tensor = top_n_tokens_tensor
        self.max_input_length = max_input_length
        self.max_decoder_input_length = max_decoder_input_length
        self.padding_right_offset = padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors"""
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        max_tokens = 0

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        top_n_tokens_tensor = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for i, batch in enumerate(batches):
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
            top_n_tokens.extend(batch.top_n_tokens)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[start_index:end_index, -batch.max_input_length :] = (
                batch.attention_mask[:, -batch.max_input_length :]
            )

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            if top_n_tokens_tensor is None:
                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                    total_batch_size,
                )
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
            batch.encoder_last_hidden_state = None

            # Ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t for t in layer] for layer in batch.past_key_values
                ]

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length
                - batch.max_input_length
                + max_decoder_input_length
                - batch.max_decoder_input_length
            ) * len(batch)

            start_index = end_index

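        # past_key_values layers are [dec_key, dec_value, enc_key, enc_value]; decoder
        # entries are padded to max_decoder_input_length - 1, encoder entries to max_input_length.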
        # Determine shapes for new past kv tensors
        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][0].shape
        padded_dec_t_shape = (
            total_batch_size,
            num_heads,
            (max_decoder_input_length - 1),
            head_dim,
        )

        padded_enc_t_shape = (
            total_batch_size,
            num_heads,
            max_input_length,
            head_dim,
        )

        # Iterate over attention layers
        for j in range(len(first_past_kvs)):
            past_key_values.append([])

            # Decoder past
            for k in range(0, 2):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    past_seq_len = batch.max_decoder_input_length - 1
                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
                        :, :, -past_seq_len:, :
                    ]
                    del t

                    start_index = end_index

            # Encoder past
            for k in range(2, 4):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    padded_past_values[
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]
                    del t

                    start_index = end_index

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class Seq2SeqLM(Model):
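    """Encoder-decoder model loaded from safetensors weights via ``Weights``;
    the ``fallback`` constructor loads through ``AutoModelForSeq2SeqLM`` instead."""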
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        config_class=AutoConfig,
        tokenizer_class=AutoTokenizer,
        aliases=None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                # Float16 doesn't exist on target.
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        config = config_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        weights_loader = get_loader(
            quantize=quantize, model_id=model_id, revision=revision
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases=aliases,
            weights_loader=weights_loader,
        )
        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = model_class(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if speculator:
            raise RuntimeError("Speculative decoding is not enabled for AutoModel")

        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=(
                "auto"
                if torch.cuda.is_available() and torch.cuda.device_count() > 1
                else None
            ),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = model.config.decoder_start_token_id

        self = cls.__new__(
            cls,
        )
        super().__init__(
            self,
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )
        return self

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[torch.Tensor],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        if isinstance(outputs, tuple):
            # Our custom models
            outputs, speculative_logits = outputs
        else:
            # Generic transformers models
            speculative_logits = None
        return (
            outputs.logits,
            speculative_logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
        start = time.time_ns()
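        # Drop the right-padding columns reserved for future tokens before the forward pass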
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # Speculation is not active for seq2seq
        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            torch.log_softmax(logits[:, -1], -1),
            accepted_ids,
        )

        start_decode = time.time_ns()

        # Finished requests
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
            batch.top_n_tokens,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
            top_n_tokens,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), logits[-1:, :]
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_decoder_input_ids, prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Slice with decoder_input_length to remove padding
                    # Decode all tokens
                    output_text, _, _ = self.decode_token(
                        all_decoder_input_ids,
                        prefix_offset=len(all_decoder_input_ids)
                        - decoder_input_length
                        - 1,
                        read_offset=len(all_decoder_input_ids) - decoder_input_length,
                        skip_special_tokens=True,
                    )

                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    prefill_tokens = Tokens(
                        [self.tokenizer.bos_token_id],
                        [float("nan")],
                        [self.tokenizer.bos_token],
                        [False],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    for top_token_ids, top_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        [next_token_id_squeezed],
                        [next_token_logprob],
                        [next_token_text],
                        [next_token_id_squeezed.item() in self.all_special_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
                next_token_id_squeezed.item()
            )
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, input_length)
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # We finished all generations in the batch; there is no next batch
        if stopped:
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask as we added a new token to input_ids
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)