# seq2seq_lm.py
import torch

from dataclasses import dataclass
4
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
5
6
7
from typing import Optional, Tuple, List, Type

from text_generation.models import Model
8
from text_generation.models.types import GeneratedText, Batch
9
10
11
12
13
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria


@dataclass
class Seq2SeqLMBatch(Batch):
    """Batch state for encoder-decoder (seq2seq) generation.

    Holds the encoder inputs, the growing decoder inputs, the model cache and the
    per-request generation helpers. When batches are merged (`concatenate`), each
    tensor is copied right-aligned into a zero-filled tensor, so the meaningful
    values of every row sit at the end of the sequence axis and padding can be
    trimmed with negative slices.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Encoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    # None while all decoder sequences of the batch have the same length
    # (set by `concatenate`, which always builds an explicit mask)
    decoder_attention_mask: Optional[torch.Tensor]
    # None until the first forward pass has run the encoder
    encoder_last_hidden_state: Optional[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    # None until the first forward pass
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
    # Log-probabilities of the generated tokens, one tensor per request
    # (None until the request has generated its first token)
    decoder_logprobs: List[Optional[torch.Tensor]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_input_length: int
    max_decoder_input_length: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch.

        All request inputs are tokenized together with `padding=True` (padded to
        a multiple of 8 on CUDA). Every decoder sequence starts with only
        `tokenizer.bos_token_id`. The decoder mask, encoder states and cache stay
        None until the first forward pass.

        Args:
            pb: protobuf batch to convert.
            tokenizer: tokenizer used for the encoder inputs and stopping criteria.
            device: device the tensors are created on.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        decoder_input_ids = []
        decoder_input_lengths = []
        decoder_logprobs = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            # Decoder sequence only contains the bos_token
            decoder_input_ids.append(tokenizer.bos_token_id)
            decoder_input_lengths.append(1)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
            stopping_criterias.append(
                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
            )
            # Nothing generated yet, so no log-probabilities
            decoder_logprobs.append(None)

        # Tokenize batch
        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=pad_to_multiple_of,
        ).to(device)
        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            decoder_logprobs=decoder_logprobs,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=len(pb.requests),
            max_input_length=max(input_lengths),
            max_decoder_input_length=1,
        )

    @classmethod
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors.

        Each batch's tensors are trimmed to that batch's own maximum lengths
        (negative slices drop leading padding). They are then copied, right-aligned,
        into zero-filled tensors sized for the combined batch. Every batch must
        already have run a forward pass: its `encoder_last_hidden_state` and
        `past_key_values` are read here. The result keeps the id of the first batch.

        Raises:
            ValueError: if `batches` is empty, or if a batch has no
                `encoder_last_hidden_state` yet.
        """
        if not batches:
            raise ValueError("Cannot concatenate an empty list of batches")

        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_input_length = max(batch.max_input_length for batch in batches)
        max_decoder_input_length = max(
            batch.max_decoder_input_length for batch in batches
        )

        # Batch attributes
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        decoder_logprobs = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors (allocated lazily from the first batch so that dtype and
        # device follow the existing tensors)
        input_ids = None
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for batch in batches:
            # Extend all list attributes
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            decoder_logprobs.extend(batch.decoder_logprobs)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if input_ids is None:
                input_ids = batch.input_ids.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            input_ids[
                start_index:end_index, -batch.max_input_length :
            ] = batch.input_ids[:, -batch.max_input_length :]

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, max_decoder_input_length),
                )
            # Copy to correct indices
            decoder_input_ids[
                start_index:end_index, -batch.max_decoder_input_length :
            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length),
                )
            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index, -batch.max_decoder_input_length :
                ] = 1
            # If it exists, we need to index
            else:
                decoder_attention_mask[
                    start_index:end_index, -batch.max_decoder_input_length :
                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

            # Iterate over attention layers
            for j, past in enumerate(batch.past_key_values):
                _, num_heads, _, head_dim = past[0].shape

                # This will run only once per layer
                if j == len(past_key_values):
                    past_key_values.append([])

                # Decoder past: its sequence axis is one shorter than the
                # decoder inputs because the newest decoder token has not been
                # fed through the model yet
                for k, t in enumerate(past[:2]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        (max_decoder_input_length - 1),
                        head_dim,
                    )

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    # We slice the past keys and values to remove the padding from previous batches
                    past_key_values[j][k][
                        start_index:end_index,
                        :,
                        -(batch.max_decoder_input_length - 1) :,
                        :,
                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]

                # Encoder past: its sequence axis follows the encoder input length
                for k, t in enumerate(past[2:]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        max_input_length,
                        head_dim,
                    )

                    idx = k + 2

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if idx == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(padded_t_shape))

                    past_key_values[j][idx][
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            decoder_logprobs=decoder_logprobs,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
        )


class Seq2SeqLM(Model):
    """Text-generation model backed by a Transformers `AutoModelForSeq2SeqLM`."""

    def __init__(self, model_name: str, quantize=False):
        """Load the model and its tokenizer.

        Args:
            model_name: name or path passed to `from_pretrained`.
            quantize: load the weights with `load_in_8bit`. Only allowed when
                CUDA is available.

        Raises:
            ValueError: if `quantize` is requested on a CPU-only host.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        # Left padding keeps the real tokens at the end of each row, which the
        # batch trimming/concatenation logic relies on
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        # Seq2SeqLMBatch.from_pb starts every decoder sequence with
        # `bos_token_id`; point it at the model's decoder start token
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[torch.Tensor],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run one model forward pass with caching enabled.

        When `past_key_values` is given, only the last decoder token is fed to
        the model. When `encoder_last_hidden_state` is given, it is passed as
        `encoder_outputs`.

        Returns:
            `(logits, encoder_last_hidden_state, past_key_values)` from the model
            outputs.
        """
        # Model Forward
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [encoder_last_hidden_state]

        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]:
        """Run one decoding step for every request in `batch`.

        Each request gets one new token, and its stopping criteria are evaluated.
        Finished requests are turned into `GeneratedText`s. The others are kept in
        a new batch whose tensors and cache are filtered to the remaining rows.

        Returns:
            The generations that finished during this step, and the batch for the
            next step (None when every request finished).
        """
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
        )
        with context_manager():
            logits, encoder_last_hidden_state, past = self.forward(
                batch.input_ids,
                batch.attention_mask,
                batch.decoder_input_ids,
                batch.decoder_attention_mask,
                batch.encoder_last_hidden_state,
                batch.past_key_values,
            )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []
        next_batch_decoder_logprobs = []

        # Metadata
        next_batch_size = 0
        next_batch_max_input_length = 0
        next_batch_max_decoder_input_length = 0

        # Finished requests
        generated_texts: List[GeneratedText] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.decoder_input_lengths,
            batch.decoder_logprobs,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            decoder_input_length,
            decoder_logprobs,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token, step_logprobs = next_token_chooser(
                decoder_input_ids, request_logits
            )

            # Append next token to decoder tokens
            decoder_input_ids = torch.cat([decoder_input_ids, next_token])
            new_decoder_input_length = decoder_input_length + 1

            # Accumulate the log-probability of the chosen token
            next_token_logprob = step_logprobs[-1, next_token]
            if decoder_logprobs is None:
                decoder_logprobs = next_token_logprob
            else:
                decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token.squeeze(),
                self.tokenizer.decode(
                    next_token.squeeze(), clean_up_tokenization_spaces=False
                ),
            )
            if stop:
                # Slice with decoder_input_length to remove padding
                # Decode all tokens
                token_ids = decoder_input_ids[-new_decoder_input_length:]
                output_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
                tokens = self.tokenizer.batch_decode(token_ids)
                # Add NaN for the bos token
                token_logprobs = [float("nan")] + decoder_logprobs[
                    -new_decoder_input_length:
                ].tolist()
                # Add to the list of finished generations with the original request
                generated_texts.append(
                    GeneratedText(
                        request=request,
                        output_text=output_text,
                        generated_tokens=stopping_criteria.current_tokens,
                        tokens=tokens,
                        token_ids=token_ids.tolist(),
                        logprobs=token_logprobs,
                        reason=reason,
                    )
                )
            # add to the next batch
            else:
                next_batch_keep_indices.append(i)
                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
                next_batch_size += 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
                next_batch_decoder_logprobs.append(decoder_logprobs)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )
                next_batch_max_decoder_input_length = max(
                    next_batch_max_decoder_input_length, new_decoder_input_length
                )

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generated_texts, None

        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if generated_texts:
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]

            if batch.decoder_attention_mask is not None:
                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
                    next_batch_keep_indices
                ]
            else:
                next_batch_decoder_attention_mask = None

            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
                next_batch_keep_indices
            ]

            next_batch_past_key_values = [
                [t[next_batch_keep_indices] for t in layer] for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_input_ids = batch.input_ids
            next_batch_attention_mask = batch.attention_mask
            next_batch_decoder_attention_mask = batch.decoder_attention_mask
            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
            next_batch_past_key_values = past

            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update decoder_attention_mask with padding as we added a new token to input_ids
        # (without a mask, every row has the same length and nothing needs masking)
        if next_batch_decoder_attention_mask is not None:
            next_batch_decoder_attention_mask = torch.cat(
                [
                    next_batch_decoder_attention_mask,
                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
                ],
                dim=1,
            )

        next_batch = Seq2SeqLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            decoder_input_ids=next_batch_decoder_input_ids,
            decoder_attention_mask=next_batch_decoder_attention_mask,
            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            decoder_input_lengths=next_batch_decoder_input_lengths,
            decoder_logprobs=next_batch_decoder_logprobs,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            max_decoder_input_length=next_batch_max_decoder_input_length,
        )
        return generated_texts, next_batch