# seq2seq_lm.py
1
2
3
import torch

from dataclasses import dataclass
4
from opentelemetry import trace
5
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
6
7
8
from typing import Optional, Tuple, List, Type

from text_generation.models import Model
9
from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
10
from text_generation.pb import generate_pb2
11
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
12

13
14
tracer = trace.get_tracer(__name__)

15
16

@dataclass
class Seq2SeqLMBatch(Batch):
    """State carried between decoding steps for a batch of seq2seq requests.

    The batch stores the tokenized encoder inputs, the decoder tokens produced
    so far, the cached encoder output / attention keys and values, and the
    per-request sampling and stopping helpers.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Encoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_input_length: int
    max_decoder_input_length: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch

        The returned batch has not run any forward pass yet: the decoder side
        holds a single `bos_token_id` per request and `decoder_attention_mask`,
        `encoder_last_hidden_state` and `past_key_values` are all `None`.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        decoder_input_ids = []
        decoder_input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            # Decoder sequence only contains the bos_token
            decoder_input_ids.append(tokenizer.bos_token_id)
            decoder_input_lengths.append(1)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criterias.append(
                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
            )

        # Tokenize batch
        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=pad_to_multiple_of,
            return_token_type_ids=False,
        ).to(device)
        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=len(pb.requests),
            max_input_length=max(input_lengths),
            max_decoder_input_length=1,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors

        Every tensor of every batch is copied, right-aligned, into a freshly
        allocated zero tensor sized for the combined batch. The resulting batch
        reuses the `batch_id` of the first batch.

        Raises:
            ValueError: if `batches` is empty, or if any batch has not run at
                least one step (its `encoder_last_hidden_state` is `None`).
        """
        if not batches:
            # Fail with an explicit message instead of the opaque error `max()`
            # would raise on an empty sequence
            raise ValueError("Cannot concatenate an empty list of batches")

        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_input_length = max(batch.max_input_length for batch in batches)
        max_decoder_input_length = max(
            batch.max_decoder_input_length for batch in batches
        )

        # Batch attributes
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for batch in batches:
            # Extend all list attributes
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if input_ids is None:
                input_ids = batch.input_ids.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            input_ids[
                start_index:end_index, -batch.max_input_length :
            ] = batch.input_ids[:, -batch.max_input_length :]

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, max_decoder_input_length),
                )
            # Copy to correct indices
            decoder_input_ids[
                start_index:end_index, -batch.max_decoder_input_length :
            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index, -batch.max_decoder_input_length :
                ] = 1
            # If it exists, we need to index
            else:
                decoder_attention_mask[
                    start_index:end_index, -batch.max_decoder_input_length :
                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

            # Iterate over attention layers
            for j, past in enumerate(batch.past_key_values):
                _, num_heads, _, head_dim = past[0].shape

                # This will run only once per layer
                if j == len(past_key_values):
                    past_key_values.append([])

                # Padded shapes only depend on the layer, so compute them once
                # The decoder past is one token shorter than the decoder input
                decoder_past_shape = (
                    total_batch_size,
                    num_heads,
                    (max_decoder_input_length - 1),
                    head_dim,
                )
                encoder_past_shape = (
                    total_batch_size,
                    num_heads,
                    max_input_length,
                    head_dim,
                )

                # Decoder past
                for k, t in enumerate(past[:2]):
                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(decoder_past_shape))

                    # We slice the past keys and values to remove the padding from previous batches
                    past_key_values[j][k][
                        start_index:end_index,
                        :,
                        -(batch.max_decoder_input_length - 1) :,
                        :,
                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]

                # encoder past
                for k, t in enumerate(past[2:]):
                    idx = k + 2

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if idx == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(encoder_past_shape))

                    past_key_values[j][idx][
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
        )

    def __len__(self):
        """Return the number of requests held by the batch"""
        return len(self.requests)

294
295

class Seq2SeqLM(Model):
    """Encoder-decoder model wrapper built on `AutoModelForSeq2SeqLM`."""

    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        """Load the model and its tokenizer.

        Args:
            model_id: identifier forwarded to `from_pretrained`.
            revision: optional revision forwarded to `from_pretrained`.
            quantize: forwarded as `load_in_8bit` when loading the model.

        Raises:
            ValueError: if `quantize` is requested while CUDA is unavailable.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        # The batch seeds every decoder sequence with `tokenizer.bos_token_id`,
        # so point it at the token the model config declares as decoder start
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super().__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        """Batch class used with this model"""
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        """Decode decoder token ids to text, skipping special tokens"""
        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[torch.Tensor],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run the underlying model with caching enabled.

        Returns:
            A tuple `(logits, encoder_last_hidden_state, past_key_values)` read
            from the model outputs.
        """
        # Model Forward
        if past_key_values is not None:
            # With cached keys/values only the newest decoder token is fed
            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [encoder_last_hidden_state]

        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        """Generate one token for every request of `batch`.

        Returns:
            A tuple `(generations, next_batch)`. `generations` holds one
            `Generation` per request of `batch`. `next_batch` contains only the
            requests whose stopping criteria did not trigger, or is `None`
            when every request finished.
        """
        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            batch.decoder_attention_mask,
            batch.encoder_last_hidden_state,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []

        # Metadata
        next_batch_size = 0
        next_batch_max_input_length = 0
        next_batch_max_decoder_input_length = 0

        # Finished requests
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.decoder_input_ids,
        )

        # For each member of the batch
        # Per-request names are kept distinct from the batch-level tensors
        for i, (
            request,
            input_length,
            decoder_input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            request_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                request_decoder_input_ids.view(1, -1), request_logits
            )

            # Append next token to decoder tokens
            request_decoder_input_ids = torch.cat(
                [request_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.tokenizer.decode(
                next_token_id_squeezed,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=False,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with decoder_input_length to remove padding
                # Decode all tokens
                output_text = self.decode(
                    request_decoder_input_ids[-new_decoder_input_length:]
                )

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_decoder_input_ids.append(
                    request_decoder_input_ids.unsqueeze(0)
                )
                next_batch_size += 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )
                next_batch_max_decoder_input_length = max(
                    next_batch_max_decoder_input_length, new_decoder_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Everything before the token generated in this step
                prefill_token_ids = request_decoder_input_ids[
                    -new_decoder_input_length:-1
                ]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, [float("nan")], prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]

            if batch.decoder_attention_mask is not None:
                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
                    next_batch_keep_indices
                ]
            else:
                next_batch_decoder_attention_mask = None

            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
                next_batch_keep_indices
            ]

            next_batch_past_key_values = [
                [t[next_batch_keep_indices] for t in layer] for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_input_ids = batch.input_ids
            next_batch_attention_mask = batch.attention_mask
            next_batch_decoder_attention_mask = batch.decoder_attention_mask
            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
            next_batch_past_key_values = past

            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update decoder_attention_mask with padding as we added a new token to input_ids
        if next_batch_decoder_attention_mask is not None:
            next_batch_decoder_attention_mask = torch.cat(
                [
                    next_batch_decoder_attention_mask,
                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
                ],
                dim=1,
            )

        next_batch = Seq2SeqLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            decoder_input_ids=next_batch_decoder_input_ids,
            decoder_attention_mask=next_batch_decoder_attention_mask,
            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            decoder_input_lengths=next_batch_decoder_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            max_decoder_input_length=next_batch_max_decoder_input_length,
        )
        return generations, next_batch