import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)

@dataclass
class CausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    # Past metadata
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(r.inputs)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
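        # Resulting layout per row: left padding from the tokenizer, then the prompt,
        # then `padding_right_offset` zeroed slots reserved for future decode steps:
        #   [0 ... 0 | prompt attention | 0 ... 0]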

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
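        # Example: for a left-padded attention mask [0, 0, 1, 1, 1], cumsum(-1) - 1
        # gives [-1, -1, 0, 1, 2]; masked_fill_ then sets the padded positions to 1,
        # yielding [1, 1, 0, 1, 2]. The padded positions are never attended to anyway.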
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

        max_tokens = len(inputs) * max_input_length + max_decode_tokens
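        # Upper bound on the number of token slots this batch can ever occupy:
        # every prompt padded to max_input_length, plus every request's max_new_tokens.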

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            new_padding_right_offset = max(
                new_padding_right_offset, remaining_decode_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
        ]
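        # Example: with attention_mask.shape[1] == 10, padding_right_offset == 3,
        # max_input_length == 4 and new_padding_right_offset == 2, this keeps columns
        # 3:9, i.e. the last max_input_length input columns plus the new right padding.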

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.offsets = offsets
        self.token_offsets = token_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        max_tokens = 0

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]
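            # Example: if the merged mask is max_input_length == 6 columns wide (plus right
            # padding) and this batch only has batch.max_input_length == 4, left_offset == 2,
            # so its rows are copied right-aligned against the shared input boundary.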

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length - batch.max_input_length
            ) * len(batch)

            start_index = end_index

        first_past_kvs = batches[0].past_key_values
        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[
                        start_index:end_index, :, -past_seq_len:, :
                    ] = past_keys[:, :, -past_seq_len:, :]
                else:
                    # BLOOM case
                    padded_past_keys[
                        start_index:end_index, :, :, -past_seq_len:
                    ] = past_keys[:, :, :, -past_seq_len:]
                del past_keys

                start_index = end_index

            padded_past_values = first_past_kvs[j][1].new_zeros(
                padded_past_values_shape
            )
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[
                    start_index:end_index, :, -past_seq_len:, :
                ] = past_values[:, :, -past_seq_len:, :]
                del past_values

                # Update values
                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class CausalLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        decode_buffer: int = 3,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize == "bitsandbytes",
        ).eval()
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), logits[-1:, :]
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids[:, 0], offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :, 0]
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

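                # On the first decode step (stopping_criteria.current_tokens == 1) we also
                # return the prompt ("prefill") tokens with their logprobs; the very first
                # prompt token gets NaN since there is nothing before it to condition on.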
                # Prefill
                if stopping_criteria.current_tokens == 1:
                    # Remove generated token to only have prefill and add nan for first prompt token
                    prefill_logprobs = [float("nan")] + torch.log_softmax(
                        logits, -1
                    ).gather(1, all_input_ids[1:]).squeeze(1)[
                        -new_input_length:-1
                    ].tolist()
                    prefill_token_ids = all_input_ids[-new_input_length:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id_squeezed,
                    next_token_logprob,
                    next_token_text,
                    next_token_id_squeezed.item() in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            # Update values
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1
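        # Example: a row that looked like [.. 1 1 | 0 0 0] now reads [.. 1 1 1 | 0 0]:
        # one reserved right-padding slot is consumed by the token generated this step.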

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        return generations, batch
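

# Rough usage sketch (not part of the original module): the server drives this model by
# building a CausalLMBatch from a protobuf Batch and then calling generate_token() in a
# loop until every request has stopped. `stub_pb_batch` below is hypothetical; in the
# real deployment the protobuf Batch comes from the router, not from Python code.
#
#   model = CausalLM("gpt2")
#   batch = CausalLMBatch.from_pb(stub_pb_batch, model.tokenizer, torch.device("cuda"))
#   while batch is not None:
#       generations, batch = model.generate_token(batch)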