"""Encoder-decoder (seq2seq) language model support for text_generation_server.

Defines ``Seq2SeqLMBatch``, the per-batch generation state passed between
decoding steps, and ``Seq2SeqLM``, which wraps a Hugging Face
``AutoModelForSeq2SeqLM`` and produces one token per request per step.
"""
from dataclasses import dataclass
from typing import Optional, Tuple, List, Type

import torch
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase

from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
    Batch,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

# OpenTelemetry tracer; its `start_as_current_span` decorates the
# batch-concatenation and token-generation methods below.
tracer = trace.get_tracer(__name__)

@dataclass
class Seq2SeqLMBatch(Batch):
    """Generation state for a batch of requests served by an encoder-decoder model.

    Encoder-side tensors are left-padded to ``max_input_length``. Decoder-side
    tensors are left-padded to ``max_decoder_input_length``; the decoder
    attention mask additionally reserves ``padding_right_offset`` columns on
    the right for tokens that will be generated in later steps.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Encoder values
    # `input_ids` is only needed for the first forward pass; later batches
    # carry `None` here and reuse `encoder_last_hidden_state` instead.
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    # Per-request state threaded through `Model.decode_token` between steps
    # (None before the first generated token)
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_input_length: int
    max_decoder_input_length: int
    # Number of still-unused columns on the right of `decoder_attention_mask`
    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch.

        Every request's input is tokenized (left truncation to the largest
        ``truncate`` value in the batch is configured on the tokenizer), each
        decoder sequence is seeded with ``tokenizer.bos_token_id``, and the
        per-request sampling and stopping helpers are built. The returned batch
        has not run any forward pass yet: ``past_key_values``,
        ``encoder_last_hidden_state`` and ``decoder_attention_mask`` are None.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        decoder_input_ids = []
        decoder_input_lengths = []

        offsets = []
        token_offsets = []

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            # Decoder sequence only contains the bos_token
            decoder_input_ids.append(tokenizer.bos_token_id)
            decoder_input_lengths.append(1)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            # Reserve enough right-hand mask columns for the longest generation
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=len(pb.requests),
            max_input_length=max_input_length.item(),
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors.

        Every input batch must already have run at least one forward pass
        (``encoder_last_hidden_state`` is set). Encoder-side tensors are
        left-padded to the largest ``max_input_length``, decoder-side tensors
        to the largest ``max_decoder_input_length``, and the decoder attention
        mask keeps the largest ``padding_right_offset`` free columns on the
        right.

        Raises:
            ValueError: if a batch has not run a forward pass yet.
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += batch.size
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        offsets = []
        token_offsets = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for batch in batches:
            # Extend all list attributes
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices (right-aligned: inputs are left-padded)
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, max_decoder_input_length),
                )
            # Copy to correct indices
            decoder_input_ids[
                start_index:end_index, -batch.max_decoder_input_length :
            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # This batch's decoder positions occupy columns
            # [left_offset, max_decoder_input_length). The end index is explicit
            # rather than `-padding_right_offset` so that an offset of 0 does
            # not turn the slice into an empty `:-0`.
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
                # If the decoder mask does not exist yet, all generations started at the same time and we never
                # concatenated this batch. All generations are of length `batch.max_decoder_input_length`.
                decoder_attention_mask[
                    start_index:end_index, left_offset:max_decoder_input_length
                ] = 1
            else:
                # If it exists, we need to index past its own left padding
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index, left_offset:max_decoder_input_length
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : batch_left_offset
                    + batch.max_decoder_input_length,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

            # Iterate over attention layers
            for j, past in enumerate(batch.past_key_values):
                _, num_heads, _, head_dim = past[0].shape

                # This will run only once per layer
                if j == len(past_key_values):
                    past_key_values.append([])

                # Decoder past: the first two tensors of each layer, sized
                # max_decoder_input_length - 1 along dim 2 because the newest
                # decoder token has not been through the model yet
                for k, t in enumerate(past[:2]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        (max_decoder_input_length - 1),
                        head_dim,
                    )

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    # We slice the past keys and values to remove the padding from previous batches
                    past_key_values[j][k][
                        start_index:end_index,
                        :,
                        -(batch.max_decoder_input_length - 1) :,
                        :,
                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]

                # Encoder past: the remaining tensors, sized max_input_length
                for k, t in enumerate(past[2:]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        max_input_length,
                        head_dim,
                    )

                    idx = k + 2

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if idx == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    past_key_values[j][idx][
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)

class Seq2SeqLM(Model):
    """Serve an ``AutoModelForSeq2SeqLM`` one decoding step at a time."""

    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        """Load ``model_id`` (at ``revision``) and its tokenizer.

        On CUDA the model is loaded in bfloat16 when supported (float32
        otherwise) with ``device_map="auto"``; on CPU it is loaded in float32.
        ``quantize`` is forwarded as ``load_in_8bit``.

        Raises:
            ValueError: if ``quantize`` is requested without CUDA.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        # Decoder sequences are seeded with `tokenizer.bos_token_id`
        # (see Seq2SeqLMBatch.from_pb), so point it at the model's decoder start token.
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super().__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        """Batch class this model consumes and produces."""
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        """Decode token ids to text, skipping special tokens."""
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[List[torch.Tensor]],
        past_key_values: Optional[List[Tuple[torch.Tensor, ...]]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run one model forward pass with the KV cache enabled.

        ``encoder_last_hidden_state`` is passed as the model's
        ``encoder_outputs``; when it is given, ``input_ids`` may be None.

        Returns:
            ``(logits, encoder_last_hidden_state, past_key_values)`` taken from
            the model outputs.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        """Run one decoding step for every request in ``batch``.

        Returns:
            One ``Generation`` per request in ``batch`` (in batch order), and
            the batch for the next step with finished requests evicted, or
            ``None`` when every request has finished.
        """
        if batch.decoder_attention_mask is not None:
            # Drop the right-hand columns reserved for not-yet-generated tokens.
            # The end index is explicit so an offset of 0 keeps every column
            # instead of producing an empty `:-0` slice.
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : batch.decoder_attention_mask.shape[1] - batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # check if first forward or not
        if batch.past_key_values is not None:
            # Only take the last token
            decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
        else:
            decoder_input_ids = batch.decoder_input_ids

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_offsets = []
        next_batch_token_offsets = []
        next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []

        # Metadata
        next_batch_size = 0
        next_batch_max_input_length = 0
        next_batch_max_decoder_input_length = 0

        # Finished requests
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            decoder_input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            request_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                request_decoder_input_ids.view(1, -1), request_logits
            )

            # Append next token to decoder tokens
            request_decoder_input_ids = torch.cat(
                [request_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                request_decoder_input_ids, offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with decoder_input_length to remove the left padding
                # added when batches were concatenated, then decode all tokens
                output_text = self.decode(
                    request_decoder_input_ids[-new_decoder_input_length:]
                )

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_decoder_input_ids.append(
                    request_decoder_input_ids.unsqueeze(0)
                )
                next_batch_size += 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
                next_batch_offsets.append(offset)
                next_batch_token_offsets.append(token_offset)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )
                next_batch_max_decoder_input_length = max(
                    next_batch_max_decoder_input_length, new_decoder_input_length
                )

            # Prefill: on the first generated token, also report the decoder start token
            if stopping_criteria.current_tokens == 1:
                prefill_tokens = PrefillTokens(
                    [self.tokenizer.bos_token_id],
                    [float("nan")],
                    [self.tokenizer.bos_token],
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)

        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to decoder_attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            if batch.decoder_attention_mask is not None:
                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
                    next_batch_keep_indices
                ]
            else:
                next_batch_decoder_attention_mask = None

            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
                next_batch_keep_indices
            ]

            next_batch_past_key_values = [
                [t[next_batch_keep_indices] for t in layer] for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_decoder_attention_mask = batch.decoder_attention_mask
            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
            next_batch_past_key_values = past

            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update decoder_attention_mask as we added a new token to input_ids:
        # unmask the leftmost reserved column (explicit index rather than
        # `-padding_right_offset`, which would hit column 0 for an offset of 0)
        if next_batch_decoder_attention_mask is not None:
            next_batch_decoder_attention_mask[
                :,
                next_batch_decoder_attention_mask.shape[1] - batch.padding_right_offset,
            ] = 1

        next_batch = Seq2SeqLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=None,
            attention_mask=next_batch_attention_mask,
            decoder_input_ids=next_batch_decoder_input_ids,
            decoder_attention_mask=next_batch_decoder_attention_mask,
            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            decoder_input_lengths=next_batch_decoder_input_lengths,
            offsets=next_batch_offsets,
            token_offsets=next_batch_token_offsets,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            max_decoder_input_length=next_batch_max_decoder_input_length,
            padding_right_offset=batch.padding_right_offset - 1,
        )
        return generations, next_batch