# seq2seq_lm.py
import torch

from dataclasses import dataclass
4
from opentelemetry import trace
5
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
6
7
from typing import Optional, Tuple, List, Type

8
9
10
11
12
13
14
15
16
from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

# OpenTelemetry tracer for this module; the batch and model classes below use it
# to record the "concatenate" and "generate_token" spans.
tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
    """State of a batch of encoder-decoder generation requests.

    Per-request rows are right-aligned inside the padded tensors (padding sits on
    the left), which is why real values are addressed with ``-length:`` slices.
    Once a decoder attention mask exists, it also reserves
    ``padding_right_offset`` zero columns on the right for tokens that have not
    been generated yet.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Encoder values
    # input_ids is None for batches produced by `concatenate` / `generate_token`,
    # which rely on the cached encoder output instead.
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_input_length: int
    max_decoder_input_length: int
    # Number of decoder-mask columns still reserved on the right for future tokens
    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch

        Args:
            pb: protobuf batch; all request inputs are tokenized together with
                padding.
            tokenizer: tokenizer for the encoder inputs. Its ``bos_token_id``
                seeds every decoder sequence.
            device: device the tensors and token choosers are placed on.

        Returns:
            A batch that has not run any forward pass yet: no decoder attention
            mask, no encoder output and no past key values.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        decoder_input_ids = []
        decoder_input_lengths = []

        # Parse batch
        max_input_length = 0
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            # Decoder sequence only contains the bos_token
            decoder_input_ids.append(tokenizer.bos_token_id)
            decoder_input_lengths.append(1)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_input_length = max(max_input_length, r.input_length)
            # Reserve enough right-hand decoder-mask columns for the longest
            # requested generation
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
        ).to(device)
        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=len(pb.requests),
            max_input_length=max_input_length,
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors

        Every batch must already have run at least one forward pass, so that its
        ``encoder_last_hidden_state`` is set. Each batch's values are copied
        right-aligned into tensors sized for the largest batch. The resulting
        batch reuses the id of the first batch.

        Raises:
            ValueError: if a batch has no ``encoder_last_hidden_state``.
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += batch.size
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Extend all list attributes
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # Encoder attention mask: create padded tensor, then copy to correct indices
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Decoder input ids: create padded tensor, then copy to correct indices
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, max_decoder_input_length),
                )
            decoder_input_ids[
                start_index:end_index, -batch.max_decoder_input_length :
            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]

            # Decoder attention mask: create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )

            # The real decoder columns end right before the reserved right padding,
            # i.e. at column `max_decoder_input_length`. Explicit end indices are
            # used rather than `-padding_right_offset`, because a zero offset would
            # turn `x:-0` into an empty slice.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                # If the decoder mask does not exist yet, all generations started at the same time and we never
                # concatenated this batch. All generations are of length `batch.max_decoder_input_length`.
                decoder_attention_mask[
                    start_index:end_index, left_offset:max_decoder_input_length
                ] = 1
            else:
                # If it exists, we need to index past its own left padding and
                # stop before its own reserved right padding
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                batch_right_edge = batch_left_offset + batch.max_decoder_input_length
                decoder_attention_mask[
                    start_index:end_index, left_offset:max_decoder_input_length
                ] = batch.decoder_attention_mask[:, batch_left_offset:batch_right_edge]

            # Encoder output: create padded tensor, then copy to correct indices
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

            # Number of decoder positions held in this batch's past; the
            # current last token is fed again on the next forward
            decoder_past_length = batch.max_decoder_input_length - 1
            padded_decoder_past_length = max_decoder_input_length - 1

            # Iterate over attention layers
            for j, past in enumerate(batch.past_key_values):
                _, num_heads, _, head_dim = past[0].shape

                # This will run only once per layer
                if j == len(past_key_values):
                    past_key_values.append([])

                # Decoder past
                for k, t in enumerate(past[:2]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        padded_decoder_past_length,
                        head_dim,
                    )

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    # We slice the past keys and values to remove the padding from previous batches
                    past_key_values[j][k][
                        start_index:end_index,
                        :,
                        padded_decoder_past_length - decoder_past_length :,
                        :,
                    ] = t[:, :, t.shape[2] - decoder_past_length :, :]

                # encoder past
                for k, t in enumerate(past[2:]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        max_input_length,
                        head_dim,
                    )

                    idx = k + 2

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if idx == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    past_key_values[j][idx][
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
        )

    def __len__(self):
        """Return the number of requests in the batch."""
        return len(self.requests)


class Seq2SeqLM(Model):
    """Encoder-decoder text generation model served one token step at a time."""

    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        """Load a seq2seq model and a left-padding tokenizer for it.

        Args:
            model_id: model identifier passed to ``from_pretrained``.
            revision: optional model revision passed to ``from_pretrained``.
            quantize: forwarded as ``load_in_8bit``; rejected when CUDA is not
                available.

        Raises:
            ValueError: if ``quantize`` is requested without CUDA.
        """
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if use_cuda else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        # Seq2SeqLMBatch.from_pb seeds every decoder sequence with
        # `tokenizer.bos_token_id`, so point it at the model's decoder start token
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super().__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        """Batch class used by this model."""
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        """Decode decoder token ids to text, skipping special tokens."""
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[List[torch.Tensor]],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run one model forward pass with key/value caching enabled.

        Returns:
            ``(logits, encoder_last_hidden_state, past_key_values)`` taken from
            the model outputs.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        """Run one decoding step for every request in ``batch``.

        Returns:
            One ``Generation`` per request for this step, and the batch to use
            for the next step. The batch is ``None`` when every request has
            finished.
        """
        if batch.decoder_attention_mask is not None:
            # Drop the columns reserved for tokens that have not been generated
            # yet. An explicit end index avoids the `:-0` empty-slice pitfall.
            mask_width = (
                batch.decoder_attention_mask.shape[1] - batch.padding_right_offset
            )
            decoder_attention_mask = batch.decoder_attention_mask[:, :mask_width]
        else:
            decoder_attention_mask = None

        # check if first forward or not
        if batch.past_key_values is not None:
            # Only take the last token
            decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
        else:
            decoder_input_ids = batch.decoder_input_ids

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []

        # Metadata
        next_batch_size = 0
        next_batch_max_input_length = 0
        next_batch_max_decoder_input_length = 0

        # Finished requests
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            decoder_input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            request_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                request_decoder_input_ids.view(1, -1), request_logits
            )

            # Append next token to decoder tokens
            request_decoder_input_ids = torch.cat(
                [request_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.decode_token(
                next_token_id_squeezed,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with decoder_input_length to remove padding
                # Decode all tokens
                output_text = self.decode(
                    request_decoder_input_ids[-new_decoder_input_length:]
                )

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_decoder_input_ids.append(
                    request_decoder_input_ids.unsqueeze(0)
                )
                next_batch_size += 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )
                next_batch_max_decoder_input_length = max(
                    next_batch_max_decoder_input_length, new_decoder_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                prefill_tokens = PrefillTokens(
                    [self.tokenizer.bos_token_id],
                    [float("nan")],
                    [self.tokenizer.bos_token],
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to decoder_attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            if batch.decoder_attention_mask is not None:
                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
                    next_batch_keep_indices
                ]
            else:
                next_batch_decoder_attention_mask = None

            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
                next_batch_keep_indices
            ]

            next_batch_past_key_values = [
                [t[next_batch_keep_indices] for t in layer] for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_decoder_attention_mask = batch.decoder_attention_mask
            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
            next_batch_past_key_values = past

            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update decoder_attention_mask as we added a new token to input_ids:
        # mark the first reserved right-hand column as attended
        if next_batch_decoder_attention_mask is not None:
            new_token_column = (
                next_batch_decoder_attention_mask.shape[1] - batch.padding_right_offset
            )
            next_batch_decoder_attention_mask[:, new_token_column] = 1

        next_batch = Seq2SeqLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=None,
            attention_mask=next_batch_attention_mask,
            decoder_input_ids=next_batch_decoder_input_ids,
            decoder_attention_mask=next_batch_decoder_attention_mask,
            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            decoder_input_lengths=next_batch_decoder_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            max_decoder_input_length=next_batch_max_decoder_input_length,
            # One reserved column was consumed by the token generated this step
            padding_right_offset=batch.padding_right_offset - 1,
        )
        return generations, next_batch