"""Encoder-decoder (seq2seq) language model: batch state and generation loop."""

import time
from dataclasses import dataclass
from typing import Optional, Tuple, List, Type, Dict

import torch
import torch.distributed
from opentelemetry import trace
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    PreTrainedTokenizerBase,
    AutoConfig,
)

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    Tokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
    """Generation state for a batch of encoder-decoder requests.

    Holds the tokenized encoder inputs, the decoder state, the cached
    encoder/decoder key/values and the per-request generation helpers.
    Tensors are indexed by row; `requests_idx_mapping` maps a request id to
    its row. After `filter` or `concatenate`, `input_ids` is None and the
    encoder output is carried in `encoder_last_hidden_state` instead.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> row index in the batch tensors
    requests_idx_mapping: Dict[int, int]

    # Encoder values
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
    all_decoder_input_ids: List[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values.
    # Per layer: entries 0-1 are the decoder past (sliced along dim 2 by decoder
    # length), entries 2-3 the encoder past (sliced along dim 2 by input length).
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    # Number of columns reserved on the right of `decoder_attention_mask`
    # for tokens that have not been generated yet
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch.

        Args:
            pb: incoming batch of requests.
            tokenizer: tokenizer for the encoder inputs; also passed to
                `NextTokenChooser.from_pb` and `StoppingCriteria.from_pb`.
            dtype: not used by this method.
            device: device on which all tensors are created.

        Returns:
            A batch that has not run a forward pass yet: no decoder attention
            mask, no encoder hidden state, no past key values.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            inputs.append(concat_text_chunks(r.input_chunks.chunks))
            requests_idx_mapping[r.id] = i
            decoder_input_lengths.append(1)
            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        # Convert to a Python int right away: it feeds `max_tokens`, which is
        # declared as `int` and sent back in the CachedBatch protobuf.
        max_input_length = input_lengths.max().item()

        # Decoder sequence only contains the bos_token
        decoder_input_ids = (
            torch.tensor(tokenizer.bos_token_id, device=device)
            .repeat(len(pb.requests))
            .view(-1, 1)
        )
        for _ in pb.requests:
            prefix_offsets.append(0)
            read_offsets.append(1)
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        # NOTE(review): `max_decode_tokens` is already a sum over requests, so
        # multiplying it by the batch size looks like an over-estimate compared
        # with `filter`, which adds the remaining decode budget once. Kept as-is
        # (over-estimation is conservative); confirm the intended formula.
        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=list(all_decoder_input_ids),
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
        """Keep only the requests in `request_ids`, slicing cached tensors in place.

        Rows are reordered to follow `request_ids`. Padding that is no longer
        needed (input length, decoder length, right padding) is trimmed.
        Expects the batch to have run at least one forward pass
        (`encoder_last_hidden_state` and `past_key_values` set).

        Args:
            request_ids: ids of the requests to keep; presumably a subset of
                the current request ids (only the length is compared for the
                fast path).

        Returns:
            `self`, updated in place.

        Raises:
            ValueError: if `request_ids` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_decoder_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []

        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        total_remaining_decode_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            request_decoder_input_length = self.decoder_input_lengths[idx]
            decoder_input_lengths.append(request_decoder_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, request_decoder_input_length
            )

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(self.top_n_tokens[idx])
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
        if self.decoder_attention_mask is not None:
            # Keep the last `max_decoder_input_length` real columns and shrink
            # the right padding from the old offset to the new one.
            self.decoder_attention_mask = self.decoder_attention_mask[
                keep_indices,
                -(self.padding_right_offset + max_decoder_input_length) : (
                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
                )
                + padding_right_offset,
            ]

        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
            keep_indices, -max_input_length:
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if isinstance(self.past_key_values[0], tuple):
            self.past_key_values = [
                [t for t in layer] for layer in self.past_key_values
            ]

        decoder_past_seq_len = max_decoder_input_length - 1
        for layer in self.past_key_values:
            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
            layer[2] = layer[2][keep_indices, :, -max_input_length:]
            layer[3] = layer[3][keep_indices, :, -max_input_length:]

        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
        # Each kept request can still grow by its own remaining decode budget,
        # so the budgets are summed over all kept requests.
        max_tokens = (
            len(request_ids) * (max_input_length + max_decoder_input_length)
            + total_remaining_decode_tokens
        )

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = None
        self.all_decoder_input_ids = all_decoder_input_ids
        self.input_lengths = input_lengths
        self.decoder_input_lengths = decoder_input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.top_n_tokens = top_n_tokens
        self.top_n_tokens_tensor = top_n_tokens_tensor
        self.max_input_length = max_input_length
        self.max_decoder_input_length = max_decoder_input_length
        self.padding_right_offset = padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors.

        Every batch must have run at least one forward pass. The input batches
        are consumed: their `encoder_last_hidden_state` and past key/value
        entries are set to None while copying.

        Raises:
            ValueError: if any batch has no `encoder_last_hidden_state`. Checked
                before any input batch is modified.
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        all_decoder_input_ids = []
        input_lengths = []
        decoder_input_lengths = []
        prefix_offsets = []
        read_offsets = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        max_tokens = 0

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        top_n_tokens_tensor = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for batch in batches:
            # Extend all list attributes
            requests.extend(batch.requests)
            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
            top_n_tokens.extend(batch.top_n_tokens)

            # Offset each batch's mapping by the cumulative batch size. A new
            # dict is filled so the input batches' mappings are left untouched.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[start_index:end_index, -batch.max_input_length :] = (
                batch.attention_mask[:, -batch.max_input_length :]
            )

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, 1),
                )
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )

            # Decoder columns are left-padded: this batch's tokens occupy
            # [left_offset, max_decoder_input_length); the columns after that
            # are right padding. Explicit end indices are used rather than
            # `-padding_right_offset`, which would give an empty slice when
            # the offset is 0.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
                # this batch. All generations are of length `batch.max_decoder_input_length`.
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:max_decoder_input_length,
                ] = 1
            # If it exists, we need to index
            else:
                batch_right_end = (
                    batch.decoder_attention_mask.shape[1] - batch.padding_right_offset
                )
                batch_left_offset = batch_right_end - batch.max_decoder_input_length
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:max_decoder_input_length,
                ] = batch.decoder_attention_mask[:, batch_left_offset:batch_right_end]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            if top_n_tokens_tensor is None:
                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                    total_batch_size,
                )
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
            batch.encoder_last_hidden_state = None

            # Ensure that we can update tensors in-place
            if isinstance(batch.past_key_values[0], tuple):
                batch.past_key_values = [
                    [t for t in layer] for layer in batch.past_key_values
                ]

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length
                - batch.max_input_length
                + max_decoder_input_length
                - batch.max_decoder_input_length
            ) * len(batch)

            start_index = end_index

        # Determine shapes for new past kv tensors
        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][0].shape

        padded_dec_t_shape = (
            total_batch_size,
            num_heads,
            (max_decoder_input_length - 1),
            head_dim,
        )

        padded_enc_t_shape = (
            total_batch_size,
            num_heads,
            max_input_length,
            head_dim,
        )

        # Iterate over attention layers
        for j in range(len(first_past_kvs)):
            past_key_values.append([])

            # Decoder past
            for k in range(0, 2):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    past_seq_len = batch.max_decoder_input_length - 1
                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
                        :, :, -past_seq_len:, :
                    ]
                    del t

                    start_index = end_index

            # Encoder past
            for k in range(2, 4):
                # Initialize tensors
                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
                past_key_values[j].append(padded_past_values)

                start_index = 0
                for batch in batches:
                    t = batch.past_key_values[j][k]
                    # Clear reference to the original tensor
                    batch.past_key_values[j][k] = None
                    # Slicing end index for this batch
                    end_index = start_index + len(batch)
                    # We slice the past keys and values to remove the padding from previous batches
                    padded_past_values[
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]
                    del t

                    start_index = end_index

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            all_decoder_input_ids=all_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class Seq2SeqLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        config_class=AutoConfig,
        tokenizer_class=AutoTokenizer,
        aliases=None,
    ):
        """Load a (possibly sharded) seq2seq model built by `model_class`.

        Args:
            model_id: model identifier or path passed to the config/tokenizer
                loaders and to `weight_files`.
            model_class: callable invoked as `model_class(config, weights)`.
            revision: optional model revision.
            quantize: stored as `config.quantize`; for "awq", "exl2", "gptq"
                and "marlin" the weights also get their GPTQ parameters set.
            speculator: stored as `config.speculator`.
            dtype: explicit dtype; when None a device-dependent default is used.
            default_dtype: dtype used on CUDA or XPU when `dtype` is None.
            trust_remote_code: forwarded to the config and tokenizer loaders.
            config_class: class whose `from_pretrained` loads the config.
            tokenizer_class: class whose `from_pretrained` loads the tokenizer.
            aliases: forwarded to `Weights`.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Select the device and the dtype to use when none was requested.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            implicit_dtype = default_dtype
        elif SYSTEM == "ipex" and hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
            implicit_dtype = default_dtype
        elif SYSTEM == "ipex":
            device = torch.device("cpu")
            # Float16 doesn't exist on target.
            implicit_dtype = torch.bfloat16
        else:
            device = torch.device("cpu")
            implicit_dtype = torch.float32
        if dtype is None:
            dtype = implicit_dtype

        config = config_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        # Decoding starts from the model's decoder start token.
        tokenizer.bos_token_id = config.decoder_start_token_id

        weights_loader = get_loader(
            quantize=quantize, model_id=model_id, revision=revision
        )

        torch.distributed.barrier(group=self.process_group)
        weights = Weights(
            weight_files(model_id, revision=revision, extension=".safetensors"),
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases=aliases,
            weights_loader=weights_loader,
        )
        if config.quantize in ("awq", "exl2", "gptq", "marlin"):
            weights._set_gptq_params(model_id, revision)

        model = model_class(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Build a Seq2SeqLM around a plain `AutoModelForSeq2SeqLM`.

        With more than one CUDA device the model is loaded with
        `device_map="auto"`; with exactly one it is moved to CUDA.

        Raises:
            RuntimeError: if `speculator` is set.
            ValueError: if `quantize` is set and CUDA is not available.
        """
        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

        cuda_ok = torch.cuda.is_available()
        if cuda_ok:
            device = torch.device("cuda")
            if dtype is None:
                dtype = torch.float16
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")
            device = torch.device("cpu")
            if dtype is None:
                dtype = torch.float32

        n_gpus = torch.cuda.device_count() if cuda_ok else 0

        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if n_gpus > 1 else None,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if n_gpus == 1:
            model = model.cuda()

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = model.config.decoder_start_token_id

        # Skip Seq2SeqLM.__init__ (which loads sharded weights) and initialise
        # the base class directly with the already-loaded model.
        self = cls.__new__(cls)
        super().__init__(
            self,
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )
        return self

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        """Batch class this model consumes and produces."""
        batch_cls = Seq2SeqLMBatch
        return batch_cls

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run one model step with caching enabled.

        `encoder_last_hidden_state` is passed to the model as `encoder_outputs`.

        Returns:
            (logits, speculative_logits, encoder_last_hidden_state,
            past_key_values). `speculative_logits` is None unless the model
            returned a tuple (custom models) rather than a single output object.
        """
        model_out = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Generic transformers models return a single output object; our
        # custom models return (outputs, speculative_logits).
        spec_logits = None
        if isinstance(model_out, tuple):
            model_out, spec_logits = model_out

        return (
            model_out.logits,
            spec_logits,
            model_out.encoder_last_hidden_state,
            model_out.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
        """Run one decoding step for every request in ``batch``.

        Performs a single forward pass, picks the next token for each request,
        evaluates the stopping criteria and updates ``batch`` in place
        (token choosers, decoder ids, lengths, offsets, caches, attention mask
        and ``padding_right_offset``). ``batch.input_ids`` is dropped after
        this call; later steps pass the cached ``encoder_last_hidden_state``
        instead.

        Only requests with ``i % self.world_size == self.rank`` produce a
        ``Generation`` on this shard.

        Returns:
            ``(generations, next_batch, (forward_ns, decode_ns))`` where
            ``next_batch`` is ``None`` once every request has stopped.
        """
        start = time.time_ns()

        if batch.decoder_attention_mask is not None:
            # The mask keeps `padding_right_offset` trailing columns reserved
            # for tokens not generated yet; drop them for this forward pass.
            # Use an explicit end index: `[:, :-0]` would give an empty mask
            # if the offset ever reached 0.
            mask_width = batch.decoder_attention_mask.shape[1]
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : mask_width - batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # Wrap `encoder_last_hidden_state` because Transformers indexes it as
        # `encoder_outputs[0]` internally.
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        # Speculative logits are unused: speculation is not active for seq2seq.
        logits, _, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # Without speculation exactly one token is accepted per request.
        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            torch.log_softmax(logits[:, -1], -1),
            accepted_ids,
        )

        start_decode = time.time_ns()

        generations: List[Generation] = []
        # Flipped to False as soon as any request still needs more tokens.
        stopped = True

        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
            batch.top_n_tokens,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            decoder_input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
            top_n_tokens,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Select next token from this request's last-position logits
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), request_logits[-1:, :]
            )

            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_decoder_input_ids, prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)
            if not stop:
                stopped = False

            # Shard generations: each rank reports only its share; all
            # generations are merged in the Rust sharded client.
            if i % self.world_size == self.rank:
                if stop:
                    # Decode the whole output in one go. `read_offset` skips
                    # the first `len - decoder_input_length` entries
                    # (presumably left padding plus the initial decoder
                    # token - TODO confirm); `prefix_offset` sits one token
                    # earlier to give the tokenizer context.
                    output_text, _, _ = self.decode_token(
                        all_decoder_input_ids,
                        prefix_offset=len(all_decoder_input_ids)
                        - decoder_input_length
                        - 1,
                        read_offset=len(all_decoder_input_ids) - decoder_input_length,
                        skip_special_tokens=True,
                    )

                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill: on the first generated token, report the BOS
                # decoder token with a NaN logprob.
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    prefill_tokens = Tokens(
                        [self.tokenizer.bos_token_id],
                        [float("nan")],
                        [self.tokenizer.bos_token],
                        [False],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    # Distinct names so the per-request iterables are not
                    # rebound by the loop variables.
                    for token_ids, token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in token_ids
                        ]
                        all_top_tokens.append(
                            Tokens(
                                token_ids,
                                token_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                        )
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        [next_token_id_squeezed],
                        [next_token_logprob],
                        [next_token_text],
                        [next_token_id_squeezed.item() in self.all_special_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
                next_token_id_squeezed.item()
            )
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, input_length)
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )

        # Every request finished; there is no next batch
        if stopped:
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Mark the column for the token just generated as attended and
        # consume one reserved right-padding slot.
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)