import torch
import torch.distributed
import time
from dataclasses import dataclass
from opentelemetry import trace

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    PreTrainedTokenizerBase,
    AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    Tokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)

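# Batch state for encoder-decoder generation: encoder inputs are only needed for
# the prefill forward; afterwards the cached encoder_last_hidden_state and
# past_key_values are reused and only the decoder sequence grows.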

@dataclass
class Seq2SeqLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            inputs.append(concat_text_chunks(r.input_chunks.chunks))
            requests_idx_mapping[r.id] = i
            decoder_input_lengths.append(1)
            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        # Decoder sequence only contains the bos_token
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(len(pb.requests))
            .view(-1, 1)
        )
        for _ in pb.requests:
            prefix_offsets.append(0)
            read_offsets.append(1)
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=list(all_decoder_input_ids),
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length.item(),
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []

        all_decoder_input_ids = []

        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []

        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0

        total_remaining_decode_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(self.top_n_tokens[idx])
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
        if self.decoder_attention_mask is not None:
            self.decoder_attention_mask = self.decoder_attention_mask[
                keep_indices,
                -(self.padding_right_offset + max_decoder_input_length) : (
                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
                )
                + padding_right_offset,
            ]

        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
            keep_indices, -max_input_length:
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) is tuple:
            self.past_key_values = [
                [t for t in layer] for layer in self.past_key_values
            ]

        decoder_past_seq_len = max_decoder_input_length - 1
        for layer in self.past_key_values:
            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
            layer[2] = layer[2][keep_indices, :, -max_input_length:]
            layer[3] = layer[3][keep_indices, :, -max_input_length:]

        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
        max_tokens = (
            len(request_ids) * (max_input_length + max_decoder_input_length)
            + total_remaining_decode_tokens
        )

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = None
        self.all_decoder_input_ids = all_decoder_input_ids
        self.input_lengths = input_lengths
        self.decoder_input_lengths = decoder_input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.top_n_tokens = top_n_tokens
        self.top_n_tokens_tensor = top_n_tokens_tensor
        self.max_input_length = max_input_length
        self.max_decoder_input_length = max_decoder_input_length
        self.padding_right_offset = padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors"""

        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        max_tokens = 0

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        top_n_tokens_tensor = None
        past_key_values = []
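        # Padding convention: encoder-side tensors are left-padded to the longest
        # input, decoder-side tensors are left-padded to the longest generated
        # prefix, with padding_right_offset columns reserved for future tokens.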

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for i, batch in enumerate(batches):
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
            top_n_tokens.extend(batch.top_n_tokens)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[start_index:end_index, -batch.max_input_length :] = (
                batch.attention_mask[:, -batch.max_input_length :]
            )

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            if top_n_tokens_tensor is None:
                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                    total_batch_size,
                )
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
            batch.encoder_last_hidden_state = None

            # Ensure that we can update tensors in-place
            if isinstance(batch.past_key_values[0], tuple):
                batch.past_key_values = [
                    [t for t in layer] for layer in batch.past_key_values
                ]

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length
                - batch.max_input_length
                + max_decoder_input_length
                - batch.max_decoder_input_length
            ) * len(batch)

            start_index = end_index

        # Determine shapes for new past kv tensors
        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][0].shape

        padded_dec_t_shape = (
            total_batch_size,
            num_heads,
            (max_decoder_input_length - 1),
            head_dim,
        )

        padded_enc_t_shape = (
            total_batch_size,
            num_heads,
            max_input_length,
            head_dim,
        )
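        # Each layer's past is (decoder key, decoder value, encoder key, encoder value):
        # entries 0/1 use the decoder shape above, entries 2/3 the encoder shape.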

        # Iterate over attention layers
        for j in range(len(first_past_kvs)):
            past_key_values.append([])

            # Decoder past
            for k in range(0, 2):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    past_seq_len = batch.max_decoder_input_length - 1
                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
                        :, :, -past_seq_len:, :
                    ]
                    del t

                    start_index = end_index

            # Encoder past
            for k in range(2, 4):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    padded_past_values[
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]
                    del t

                    start_index = end_index

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class Seq2SeqLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        config_class=AutoConfig,
        tokenizer_class=AutoTokenizer,
        aliases=None,
    ):
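        """Load `model_class` from safetensors through TGI's `Weights` helper,
        sharded across the process group from `initialize_torch_distributed`."""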
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            device = torch.device("cpu")
            # Float16 doesn't exist on target.
            dtype = torch.bfloat16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        config = config_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
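        # Seq2seq decoding starts from the model's `decoder_start_token_id`; expose it
        # as the tokenizer's bos_token_id so `Seq2SeqLMBatch.from_pb` can seed
        # decoder_input_ids with it.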
        tokenizer.bos_token_id = config.decoder_start_token_id

        weights_loader = get_loader(
            quantize=quantize, model_id=model_id, revision=revision
        )
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases=aliases,
            weights_loader=weights_loader,
        )
        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = model_class(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")
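        # Fallback path: load with plain `AutoModelForSeq2SeqLM` instead of the custom
        # weights-loading path used by `__init__`.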

        device_count = 0
        if torch.cuda.is_available():
            device = torch.device("cuda")
            device_count = torch.cuda.device_count()
            dtype = torch.float16 if dtype is None else dtype
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device("xpu")
            device_count = torch.xpu.device_count()
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=("auto" if device_count > 1 else None),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if device_count == 1:
            model = model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = model.config.decoder_start_token_id

        self = cls.__new__(
            cls,
        )
        super().__init__(
            self,
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )
        self.quantize = quantize
        return self

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        if isinstance(outputs, tuple):
            # Our custom models
            outputs, speculative_logits = outputs
        else:
            # Generic transformers models
            speculative_logits = None
        return (
            outputs.logits,
            speculative_logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
        start = time.time_ns()

        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )
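        # On the prefill step `batch.input_ids` is set and the cached encoder state and
        # past_key_values are None; on decode steps it is the other way around.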

        # Speculation is not active for seq2seq
        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            torch.log_softmax(logits[:, -1], -1),
            accepted_ids,
        )

        start_decode = time.time_ns()

        # Finished requests
        generations: List[Generation] = []
        stopped = True
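        # `stopped` stays True only if every request hits a stopping criterion this
        # step, in which case no next batch is returned.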

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
            batch.top_n_tokens,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
            top_n_tokens,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), logits[-1:, :]
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_decoder_input_ids, prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Slice with decoder_input_length to remove padding
                    # Decode all tokens
                    output_text, _, _ = self.decode_token(
                        all_decoder_input_ids,
                        prefix_offset=len(all_decoder_input_ids)
                        - decoder_input_length
                        - 1,
                        read_offset=len(all_decoder_input_ids) - decoder_input_length,
                        skip_special_tokens=True,
                    )

                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    prefill_tokens = Tokens(
                        [self.tokenizer.bos_token_id],
                        [float("nan")],
                        [self.tokenizer.bos_token],
                        [False],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    for top_token_ids, top_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        [next_token_id_squeezed],
                        [next_token_logprob],
                        [next_token_text],
                        [next_token_id_squeezed.item() in self.all_special_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
                next_token_id_squeezed.item()
            )
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, input_length)
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # We finished all generations in the batch; there is no next batch
        if stopped:
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask as we added a new token to input_ids
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)