# causal_lm.py
import torch

3
from dataclasses import dataclass
4
from opentelemetry import trace
5
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
6
from typing import Optional, Tuple, List, Type, Dict
7

8
9
10
11
12
13
14
15
16
from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
# Module-level OpenTelemetry tracer; the batch and model methods below
# decorate themselves with spans created from it.
tracer = trace.get_tracer(__name__)

@dataclass
class CausalLMBatch(Batch):
    """Padded batch state for causal (decoder-only) language model generation.

    The attention mask is allocated with `padding_right_offset` extra columns on
    the right, sized to the largest number of tokens any request may still
    generate, so decoding steps only need to flip one column to 1 instead of
    reallocating the mask.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # Maps request id -> row index in the batch tensors / per-request lists
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    # None until the batch has gone through its first forward pass
    past_key_values: Optional[List[Tuple]]

    # All tokens (one column tensor per request)
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    # Per-request decoding offsets; initialized to None in `from_pb`
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    # Past metadata: False for models (e.g. BLOOM) whose past keys store
    # seq_length as the last dimension instead of head_dim
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        """Serialize the batch metadata to its protobuf representation."""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Tokenize the requests of `pb` and build a batch ready for prefill.

        Args:
            pb: protobuf batch holding the requests to tokenize.
            tokenizer: tokenizer used to encode the request inputs.
            device: device on which the batch tensors are allocated.

        Returns:
            A new `CausalLMBatch` without past key values.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(r.inputs)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            # Reserve enough right padding for the longest possible generation
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_ids = tokenized_inputs["input_ids"]
        tokenized_attention_mask = tokenized_inputs["attention_mask"]

        input_lengths = tokenized_attention_mask.sum(1)
        # Convert once to a Python int so that `max_tokens` (declared `int`) is
        # not silently a 0-d device tensor
        max_input_length = int(input_lengths.max())

        # Allocate maximum attention_mask; sized from the tokenized batch so it
        # always matches the tensor copied into it
        attention_mask = input_ids.new_zeros(
            (len(inputs), max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_attention_mask

        # Positions count only real tokens; padded positions get a dummy value
        position_ids = tokenized_attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_attention_mask == 0, 1)
        all_input_ids = input_ids.T.split(1, dim=1)

        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
        """Keep only `requests` in the batch, updating this batch in place.

        Tensors are re-indexed to the kept rows. The attention mask and past
        key values are trimmed to the longest kept input, and the right padding
        is shrunk to what the kept requests can still generate.

        Args:
            requests: the requests to keep; each must already be in the batch.

        Returns:
            This batch (mutated).

        Raises:
            ValueError: if `requests` is empty.
        """
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            new_padding_right_offset = max(
                new_padding_right_offset, remaining_decode_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
        # Keep the last `max_input_length` columns before the right padding,
        # followed by only `new_padding_right_offset` padding columns
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if isinstance(self.past_key_values[0], tuple):
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                # (len(self) is still the pre-filter size here)
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.offsets = offsets
        self.token_offsets = token_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge already-prefilled batches into one, re-padding tensors.

        Every input batch must have past key values. While the merged past is
        built, each input batch's past key/value tensors are replaced by None
        layer by layer, so the input batches are left without usable past key
        values.

        Args:
            batches: the batches to merge; the first one provides `batch_id`
                and `keys_head_dim_last`.

        Returns:
            A new `CausalLMBatch` containing every request of `batches`.

        Raises:
            ValueError: if one of the batches has no past key values.
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        max_tokens = 0

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Offset each batch's mapping by the cumulative batch size. A fresh
            # dict is filled so the input batches' mappings are not mutated.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            # Explicit end indices: a `-offset` bound selects nothing when the
            # offset is 0
            merged_end = attention_mask.shape[1] - padding_right_offset
            batch_end = batch.attention_mask.shape[1] - batch.padding_right_offset
            attention_mask[
                start_index:end_index,
                left_offset:merged_end,
            ] = batch.attention_mask[
                :,
                batch_left_offset:batch_end,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if isinstance(batch.past_key_values[0], tuple):
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length - batch.max_input_length
            ) * len(batch)

            start_index = end_index

        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[
                        start_index:end_index, :, -past_seq_len:, :
                    ] = past_keys[:, :, -past_seq_len:, :]
                else:
                    # BLOOM case
                    padded_past_keys[
                        start_index:end_index, :, :, -past_seq_len:
                    ] = past_keys[:, :, :, -past_seq_len:]
                del past_keys

                start_index = end_index

            # Values of the first batch are still referenced at this point
            padded_past_values = first_past_kvs[j][1].new_zeros(
                padded_past_values_shape
            )
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[
                    start_index:end_index, :, -past_seq_len:, :
                ] = past_values[:, :, -past_seq_len:, :]
                del past_values

                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)

class CausalLM(Model):
    """Text generation model backed by a Hugging Face `AutoModelForCausalLM`."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        """Load the tokenizer and model weights.

        Args:
            model_id: model identifier passed to `from_pretrained`.
            revision: optional model revision passed to `from_pretrained`.
            quantize: load the weights with `load_in_8bit`; requires CUDA.
            decode_buffer: forwarded to the `Model` base class.

        Raises:
            ValueError: if `quantize` is requested while CUDA is unavailable.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        # Left padding/truncation keeps the most recent tokens right-aligned,
        # which the batch logic relies on
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        # Fall back to EOS as the padding token when the config defines none
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used by this model."""
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Decode token ids to text, skipping special tokens."""
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the model with caching enabled.

        Returns:
            The logits and the updated past key values.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        """Run one decoding step for every request of `batch`.

        Args:
            batch: the batch to advance; it is updated in place.

        Returns:
            The per-request generations, and the updated batch, or None when
            every request of the batch has stopped.
        """
        # Slice the attention mask to the correct shape. Explicit end index:
        # `[:, :-offset]` would select nothing if the offset were 0.
        attention_mask = batch.attention_mask[
            :, : batch.attention_mask.shape[1] - batch.padding_right_offset
        ]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), request_logits
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids[:, 0], offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                stopped = False

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Row p of `token_logprobs` is the logprob of token p + 1 given
                # tokens [0, p]; the last row scores the generated token. Inputs
                # are left-padded, so the prompt occupies the last
                # `input_length` positions and prompt tokens 2..input_length
                # are scored by the `input_length - 1` rows before the last
                # one. Slicing `[-new_input_length:-1]` would include one row
                # scoring the first prompt token against padding, misaligning
                # logprobs and ids for every padded request.
                token_logprobs = logprobs.gather(1, all_input_ids[1:]).squeeze(1)
                scored_end = token_logprobs.shape[0] - 1
                scored_start = scored_end - (input_length - 1)
                # The first prompt token has no context: report NaN for it
                prefill_logprobs = [float("nan")] + token_logprobs[
                    scored_start:scored_end
                ].tolist()
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

            # Update values
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        return generations, batch