seq2seq_lm.py 21.2 KB
Newer Older
1
2
3
import torch

from dataclasses import dataclass
4
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
5
6
7
from typing import Optional, Tuple, List, Type

from text_generation.models import Model
8
from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
9
from text_generation.pb import generate_pb2
10
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
11
12
13


@dataclass
class Seq2SeqLMBatch(Batch):
    """Padded state of a batch of encoder-decoder generation requests.

    All 2-D / 3-D / 4-D tensors are left-padded: the real tokens of every row
    occupy the rightmost columns. That is why every copy in `concatenate`
    slices with negative indices (`-length:`).
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Encoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    # None until this batch has been built by `concatenate`
    decoder_attention_mask: Optional[torch.Tensor]
    # None until the first forward pass has been run on this batch
    encoder_last_hidden_state: Optional[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values.
    # Per layer: the first two tensors are decoder (self-attention) past, the
    # remaining ones are encoder past (this is how `concatenate` slices them).
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_input_length: int
    max_decoder_input_length: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch.

        Args:
            pb: incoming batch of requests.
            tokenizer: tokenizer used to encode the request inputs; its
                ``bos_token_id`` is used as the single initial decoder token.
            device: device on which all tensors are created.

        Returns:
            A batch that has not run any forward pass yet: no encoder output,
            no past key values, and decoder sequences of length 1.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        decoder_input_ids = []
        decoder_input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            # Decoder sequence only contains the bos_token
            decoder_input_ids.append(tokenizer.bos_token_id)
            decoder_input_lengths.append(1)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
            stopping_criterias.append(
                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
            )

        # Tokenize batch
        # On GPU, pad to a multiple of 8 (presumably for kernel-friendly shapes)
        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=pad_to_multiple_of,
            return_token_type_ids=False,
        ).to(device)
        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=len(pb.requests),
            max_input_length=max(input_lengths),
            max_decoder_input_length=1,
        )

    @classmethod
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors.

        Rows are stacked in the order of ``batches``. Each batch's tensors are
        copied into the rightmost columns of zero-initialised tensors sized for
        the largest encoder / decoder length among all batches.

        Args:
            batches: batches that have each run at least one forward pass.

        Returns:
            A single batch holding every request. Its ``batch_id`` is taken
            from the first batch.

        Raises:
            ValueError: if ``batches`` is empty, or if a batch has no
                ``encoder_last_hidden_state`` (it never ran a step).
        """
        if not batches:
            raise ValueError("Cannot concatenate an empty list of batches")

        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_input_length = max(batch.max_input_length for batch in batches)
        max_decoder_input_length = max(
            batch.max_decoder_input_length for batch in batches
        )

        # Batch attributes
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors (allocated lazily from the first batch so that dtype and
        # device follow the source tensors)
        input_ids = None
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for batch in batches:
            # Extend all list attributes
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Encoder input ids
            if input_ids is None:
                input_ids = batch.input_ids.new_zeros(
                    (total_batch_size, max_input_length),
                )
            input_ids[
                start_index:end_index, -batch.max_input_length :
            ] = batch.input_ids[:, -batch.max_input_length :]

            # Encoder attention mask
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Decoder input ids
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, max_decoder_input_length),
                )
            decoder_input_ids[
                start_index:end_index, -batch.max_decoder_input_length :
            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]

            # Decoder attention mask
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length),
                )
            if batch.decoder_attention_mask is None:
                # If the decoder mask does not exist yet, all generations started at the same time and we
                # never concatenated this batch. All generations are of length `batch.max_decoder_input_length`.
                decoder_attention_mask[
                    start_index:end_index, -batch.max_decoder_input_length :
                ] = 1
            else:
                # If it exists, we need to index
                decoder_attention_mask[
                    start_index:end_index, -batch.max_decoder_input_length :
                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]

            # Encoder output
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

            # Iterate over attention layers
            for j, past in enumerate(batch.past_key_values):
                _, num_heads, _, head_dim = past[0].shape
                # Padded shapes only depend on the layer, not on the tensor
                decoder_past_shape = (
                    total_batch_size,
                    num_heads,
                    (max_decoder_input_length - 1),
                    head_dim,
                )
                encoder_past_shape = (
                    total_batch_size,
                    num_heads,
                    max_input_length,
                    head_dim,
                )

                # This will run only once per layer
                if j == len(past_key_values):
                    past_key_values.append([])

                # Decoder past: holds one entry per decoder token already fed to
                # the model, i.e. one less than the decoder sequence length
                for k, t in enumerate(past[:2]):
                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(decoder_past_shape))

                    # We slice the past keys and values to remove the padding from previous batches
                    past_key_values[j][k][
                        start_index:end_index,
                        :,
                        -(batch.max_decoder_input_length - 1) :,
                        :,
                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]

                # Encoder past: stored after the two decoder tensors
                for k, t in enumerate(past[2:], start=2):
                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(encoder_past_shape))

                    past_key_values[j][k][
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]

            start_index = end_index

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)

290
291

class Seq2SeqLM(Model):
    """Serves an ``AutoModelForSeq2SeqLM`` checkpoint with incremental decoding.

    Batches are ``Seq2SeqLMBatch`` instances. Each call to `generate_token`
    produces one new decoder token per request.
    """

    def __init__(self, model_name: str, quantize=False):
        """Load the model and tokenizer.

        Args:
            model_name: name or path passed to ``from_pretrained``.
            quantize: forwarded as ``load_in_8bit``; only allowed on CUDA.

        Raises:
            ValueError: if ``quantize`` is requested without CUDA.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if device.type == "cuda" else None,
            load_in_8bit=quantize,
        ).eval()
        # Left padding: batch tensors keep real tokens in the rightmost columns
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        # Batches seed every decoder sequence with `tokenizer.bos_token_id`, so
        # point it at the token the model expects to start decoding with
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        """Batch class used by this model."""
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        """Decode decoder token ids to text, dropping special tokens."""
        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[torch.Tensor],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run one model step.

        When ``past_key_values`` is given, only the last decoder token of each
        row is fed to the model; earlier tokens are represented by the cache.
        When ``encoder_last_hidden_state`` is given, it is passed as the
        encoder output.

        Returns:
            ``(logits, encoder_last_hidden_state, past_key_values)`` from the
            model output.
        """
        # Model Forward
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [encoder_last_hidden_state]

        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        """Generate one token for every request in ``batch``.

        Returns:
            A ``Generation`` per request (in batch order, finished or not).
            Also returns the batch to run at the next step, or ``None`` when
            every request has met its stopping criteria. Finished requests are
            evicted from the returned batch.
        """
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
        )
        with context_manager():
            logits, encoder_last_hidden_state, past = self.forward(
                batch.input_ids,
                batch.attention_mask,
                batch.decoder_input_ids,
                batch.decoder_attention_mask,
                batch.encoder_last_hidden_state,
                batch.past_key_values,
            )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []

        # Metadata
        next_batch_size = 0
        next_batch_max_input_length = 0
        next_batch_max_decoder_input_length = 0

        # One Generation per request, finished or not
        generations: List[Generation] = []

        # Zipped iterator: one entry per request, in batch order
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            decoder_input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                decoder_input_ids, request_logits
            )

            # Append next token to decoder tokens
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.tokenizer.decode(
                next_token_id_squeezed,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=False,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with decoder_input_length to remove padding
                # Decode all tokens
                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
                next_batch_size += 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )
                next_batch_max_decoder_input_length = max(
                    next_batch_max_decoder_input_length, new_decoder_input_length
                )

            # Prefill: on the first generated token, also report the decoder
            # tokens that preceded it (no logprob is available for them)
            if stopping_criteria.current_tokens == 1:
                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, [float("nan")], prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]

            if batch.decoder_attention_mask is not None:
                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
                    next_batch_keep_indices
                ]
            else:
                next_batch_decoder_attention_mask = None

            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
                next_batch_keep_indices
            ]

            next_batch_past_key_values = [
                [t[next_batch_keep_indices] for t in layer] for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_input_ids = batch.input_ids
            next_batch_attention_mask = batch.attention_mask
            next_batch_decoder_attention_mask = batch.decoder_attention_mask
            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
            next_batch_past_key_values = past

            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update decoder_attention_mask with padding as we added a new token to input_ids
        if next_batch_decoder_attention_mask is not None:
            next_batch_decoder_attention_mask = torch.cat(
                [
                    next_batch_decoder_attention_mask,
                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
                ],
                dim=1,
            )

        next_batch = Seq2SeqLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            decoder_input_ids=next_batch_decoder_input_ids,
            decoder_attention_mask=next_batch_decoder_attention_mask,
            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            decoder_input_lengths=next_batch_decoder_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            max_decoder_input_length=next_batch_max_decoder_input_length,
        )
        return generations, next_batch