from dataclasses import dataclass
from typing import List, Optional, Tuple, Type

import torch
from opentelemetry import trace
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedTokenizerBase

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    GeneratedText,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, Sampling, StoppingCriteria

# OpenTelemetry tracer used to create the spans of this module.
tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
    """State of a batch of seq2seq generation requests between decoding steps.

    Encoder-side tensors (``attention_mask``, ``encoder_last_hidden_state`` and the
    cross-attention entries of ``past_key_values``) are zero-padded on the left, so
    each request's tokens occupy the rightmost columns. The ``max_*`` fields record
    how many rightmost columns are meaningful for this batch.

    ``decoder_attention_mask``, once it exists, additionally carries
    ``padding_right_offset`` trailing columns reserved for tokens that have not been
    generated yet.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Encoder values
    # `input_ids` is only populated before the first forward pass; batches produced
    # by `concatenate` or by a decoding step carry `None` and reuse the encoder output.
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch.

        Args:
            pb: incoming batch of requests.
            tokenizer: tokenizer used to encode every request input. The rest of this
                class assumes it pads on the left.
            device: device on which all tensors are allocated.

        Returns:
            A batch that has not gone through any forward pass yet: no encoder
            output, no past, and a decoder sequence made only of `bos_token_id`.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        # Parse batch
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            # Room reserved in the decoder attention mask for the longest generation
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
        ).to(device)

        # Derive the lengths from the tensors that were actually produced. `concatenate`
        # slices with `[:, -max_input_length:]`, so these values must match the padded
        # width exactly, even if the upstream `r.input_length` disagrees with this tokenizer.
        input_lengths = tokenized_inputs["attention_mask"].sum(1).tolist()
        max_input_length = tokenized_inputs["input_ids"].shape[1]

        # Decoder sequence only contains the bos_token.
        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
        decoder_input_ids = torch.tensor(
            [tokenizer.bos_token_id for _ in pb.requests], device=device
        ).unsqueeze(-1)
        decoder_input_lengths = [1] * len(pb.requests)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=len(pb.requests),
            max_input_length=max_input_length,
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors.

        Every batch must already have gone through at least one forward pass.
        Each batch's meaningful rightmost columns are copied into the rightmost
        columns of freshly zero-initialised tensors sized for the combined batch.

        Raises:
            ValueError: if a batch has no `encoder_last_hidden_state`.
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += batch.size
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for batch in batches:
            # Extend all list attributes
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, max_decoder_input_length),
                )
            # Copy to correct indices
            decoder_input_ids[
                start_index:end_index, -batch.max_decoder_input_length :
            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )

            # Columns [left_offset, max_decoder_input_length) receive this batch's decoder
            # mask; the trailing `padding_right_offset` columns stay zero for future tokens.
            # Explicit end indices are used because a `-padding_right_offset` bound would
            # select nothing when the offset is 0.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                # If the decoder mask does not exist yet, all generations started at the same time and we never
                # concatenated this batch. All generations are of length `batch.max_decoder_input_length`.
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:max_decoder_input_length,
                ] = 1
            else:
                # If it exists, we need to index
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:max_decoder_input_length,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : batch_left_offset
                    + batch.max_decoder_input_length,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

            # Iterate over attention layers
            for j, past in enumerate(batch.past_key_values):
                _, num_heads, _, head_dim = past[0].shape

                # This will run only once per layer
                if j == len(past_key_values):
                    past_key_values.append([])

                # Decoder past: one entry per decoder token except the last one,
                # which has not been fed to the model yet
                for k, t in enumerate(past[:2]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        (max_decoder_input_length - 1),
                        head_dim,
                    )

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    # We slice the past keys and values to remove the padding from previous batches
                    past_key_values[j][k][
                        start_index:end_index,
                        :,
                        -(batch.max_decoder_input_length - 1) :,
                        :,
                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]

                # encoder past
                for k, t in enumerate(past[2:]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        max_input_length,
                        head_dim,
                    )

                    idx = k + 2

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if idx == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    past_key_values[j][idx][
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class Seq2SeqLM(Model):
    """Text-generation wrapper around a Transformers encoder-decoder model."""

    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        """Load the model and a left-padding tokenizer.

        Args:
            model_id: identifier passed to `from_pretrained`.
            revision: optional revision passed to `from_pretrained`.
            quantize: forwarded to `from_pretrained` as `load_in_8bit`.

        Raises:
            ValueError: if `quantize` is requested while CUDA is unavailable.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        # Batches seed every decoder sequence with `tokenizer.bos_token_id`; point it at
        # the model's configured decoder start token.
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super().__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        """Batch class used by this model."""
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        """Decode token ids to text, skipping special tokens."""
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[List[torch.Tensor]],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run one model forward pass with the key/value cache enabled.

        Returns:
            `(logits, encoder_last_hidden_state, past_key_values)` from the model output.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        """Run one decoding step for every request in `batch`.

        Returns:
            The `Generation` produced for each request during this step, and the batch
            to use for the next step (`None` when every request has finished).
        """
        if batch.decoder_attention_mask is not None:
            # Slice to the correct shape by dropping the columns reserved for future
            # tokens. An explicit end index is used because `[:, :-padding_right_offset]`
            # would be empty for an offset of 0.
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : batch.decoder_attention_mask.shape[1] - batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # check if first forward or not
        if batch.past_key_values is not None:
            # Only take the last token
            decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
        else:
            decoder_input_ids = batch.decoder_input_ids

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []

        # Metadata
        next_batch_size = 0
        next_batch_max_input_length = 0
        next_batch_max_decoder_input_length = 0

        # Finished requests
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            decoder_input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            request_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                request_decoder_input_ids.view(1, -1), request_logits
            )

            # Append next token to decoder tokens
            request_decoder_input_ids = torch.cat(
                [request_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.decode_token(
                next_token_id_squeezed,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with decoder_input_length to remove padding
                # Decode all tokens
                output_text = self.decode(
                    request_decoder_input_ids[-new_decoder_input_length:]
                )

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_decoder_input_ids.append(
                    request_decoder_input_ids.unsqueeze(0)
                )
                next_batch_size += 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )
                next_batch_max_decoder_input_length = max(
                    next_batch_max_decoder_input_length, new_decoder_input_length
                )

            # Prefill: on the first generated token, report the decoder start token
            # as the single prefill token, with a NaN logprob
            if stopping_criteria.current_tokens == 1:
                prefill_tokens = PrefillTokens(
                    [self.tokenizer.bos_token_id],
                    [float("nan")],
                    [self.tokenizer.bos_token],
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to decoder_attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            if batch.decoder_attention_mask is not None:
                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
                    next_batch_keep_indices
                ]
            else:
                next_batch_decoder_attention_mask = None

            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
                next_batch_keep_indices
            ]

            next_batch_past_key_values = [
                [t[next_batch_keep_indices] for t in layer] for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_decoder_attention_mask = batch.decoder_attention_mask
            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
            next_batch_past_key_values = past

            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update decoder_attention_mask as we added a new token to input_ids:
        # the first reserved right-padding column now holds the new token
        if next_batch_decoder_attention_mask is not None:
            next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1

        next_batch = Seq2SeqLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=None,
            attention_mask=next_batch_attention_mask,
            decoder_input_ids=next_batch_decoder_input_ids,
            decoder_attention_mask=next_batch_decoder_attention_mask,
            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            decoder_input_lengths=next_batch_decoder_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            max_decoder_input_length=next_batch_max_decoder_input_length,
            # One reserved column has been consumed by this step's token
            padding_right_offset=batch.padding_right_offset - 1,
        )
        return generations, next_batch