import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]

    # Encoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int
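    # Padding layout (descriptive note): encoder `input_ids`/`attention_mask` are
    # left-padded to `max_input_length`, `decoder_input_ids` are left-padded to
    # `max_decoder_input_length`, and the decoder attention mask, once it exists,
    # reserves `padding_right_offset` extra columns on the right so new tokens can
    # be marked in-place at each `generate_token` step.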

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        decoder_input_ids = []
        decoder_input_lengths = []

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            # Decoder sequence only contains the bos_token
            decoder_input_ids.append(tokenizer.bos_token_id)
            decoder_input_lengths.append(1)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=len(pb.requests),
            max_input_length=max_input_length.item(),
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
        )
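
    # Shape sketch (illustrative example, not enforced): for two requests with 5
    # and 3 input tokens, `from_pb` yields `input_ids`/`attention_mask` of shape
    # [2, 5] (left-padded), `decoder_input_ids` of shape [2, 1] holding only the
    # BOS/decoder-start token, and `max_decoder_input_length == 1`.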

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors"""

        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += batch.size
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for i, batch in enumerate(batches):
            # Extend all list attributes
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, max_decoder_input_length),
                )
            # Copy to correct indices
            decoder_input_ids[
                start_index:end_index, -batch.max_decoder_input_length :
            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

            # Iterate over attention layers
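            # Note (descriptive): for encoder-decoder models each layer's past is a
            # 4-tuple (decoder_key, decoder_value, encoder_key, encoder_value); the
            # decoder tensors grow by one position per generated token while the
            # encoder tensors stay at the encoder sequence length.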
            for j, past in enumerate(batch.past_key_values):
                _, num_heads, _, head_dim = past[0].shape

                # This will run only once per layer
                if j == len(past_key_values):
                    past_key_values.append([])

                # Decoder past
                for k, t in enumerate(past[:2]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        (max_decoder_input_length - 1),
                        head_dim,
                    )

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    # We slice the past keys and values to remove the padding from previous batches
                    past_key_values[j][k][
                        start_index:end_index,
                        :,
                        -(batch.max_decoder_input_length - 1) :,
                        :,
                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]

                # encoder past
                for k, t in enumerate(past[2:]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        max_input_length,
                        head_dim,
                    )

                    idx = k + 2

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if idx == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    past_key_values[j][idx][
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
        )

    def __len__(self):
        return len(self.requests)


class Seq2SeqLM(Model):
    def __init__(self, model_id: str, revision: Optional[str] = None, quantize: bool = False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[List[torch.Tensor]],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # check if first forward or not
        if batch.past_key_values is not None:
            # Only take the last token
            decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
        else:
            decoder_input_ids = batch.decoder_input_ids

        # Wrap `encoder_last_hidden_state` in a list because Transformers indexes
        # `encoder_outputs[0]` internally to recover the encoder's last hidden state
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = batch.encoder_last_hidden_state

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []

        # Metadata
        next_batch_size = 0
        next_batch_max_input_length = 0
        next_batch_max_decoder_input_length = 0

        # Finished requests
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                decoder_input_ids.view(1, -1), logits
            )

            # Append next token to decoder tokens
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.decode_token(
                next_token_id_squeezed,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with new_decoder_input_length to remove padding
                # Decode all tokens
                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
                next_batch_size += 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )
                next_batch_max_decoder_input_length = max(
                    next_batch_max_decoder_input_length, new_decoder_input_length
                )

            # Prefill
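            # For seq2seq models the prefill consists only of the decoder start
            # (BOS) token, reported with a NaN logprob since it was never scored
            # by the model.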
            if stopping_criteria.current_tokens == 1:
                prefill_tokens = PrefillTokens(
                    [self.tokenizer.bos_token_id],
                    [float("nan")],
                    [self.tokenizer.bos_token],
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, decoder attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            if batch.decoder_attention_mask is not None:
                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
                    next_batch_keep_indices
                ]
            else:
                next_batch_decoder_attention_mask = None

            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
                next_batch_keep_indices
            ]

            next_batch_past_key_values = [
                [t[next_batch_keep_indices] for t in layer] for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_decoder_attention_mask = batch.decoder_attention_mask
            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
            next_batch_past_key_values = past

            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update decoder_attention_mask as we added a new token to decoder_input_ids
        if next_batch_decoder_attention_mask is not None:
            next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1

        next_batch = Seq2SeqLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=None,
            attention_mask=next_batch_attention_mask,
            decoder_input_ids=next_batch_decoder_input_ids,
            decoder_attention_mask=next_batch_decoder_attention_mask,
            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            decoder_input_lengths=next_batch_decoder_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            max_decoder_input_length=next_batch_max_decoder_input_length,
            padding_right_offset=batch.padding_right_offset - 1,
        )
        return generations, next_batch
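

# Usage sketch (illustrative only; the real server drives these classes through
# its gRPC service). The checkpoint id and `pb_batch` below are assumptions, and
# the sketch assumes the `Model` base class exposes the `tokenizer` and `device`
# it receives in `__init__`:
#
#     model = Seq2SeqLM("t5-small")
#     batch = Seq2SeqLMBatch.from_pb(pb_batch, model.tokenizer, model.device)
#     while batch is not None:
#         generations, batch = model.generate_token(batch)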