causal_lm.py
import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class CausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    # Past metadata
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(r.inputs)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)
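        # After batched tokenization every row shares the same (padded) length;
        # starting the read offset at the end of the prompt makes incremental
        # detokenization begin with the first generated token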
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(0)
            read_offsets.append(input_len)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

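        # Position ids are the cumulative sum of the attention mask (0-based),
        # so left padding does not shift real token positions; padded slots get
        # a dummy position of 1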
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

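        # Upper bound on the number of tokens this batch can occupy: every
        # prompt padded to max_input_length plus all requested new tokens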
        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            new_padding_right_offset = max(
                new_padding_right_offset, remaining_decode_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
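        # Keep only the attention mask window that is still needed: the last
        # max_input_length real positions plus the (possibly smaller) right
        # padding required for the remaining decode tokens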
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        max_tokens = 0

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

            # Account for any padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length - batch.max_input_length
            ) * len(batch)

            start_index = end_index

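        # Build padded past key/value tensors for the merged batch; values are
        # always laid out as [batch_size, num_heads, seq_length, head_dim],
        # while keys may have seq_length last (BLOOM)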
        first_past_kvs = batches[0].past_key_values
        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[
                        start_index:end_index, :, -past_seq_len:, :
                    ] = past_keys[:, :, -past_seq_len:, :]
                else:
                    # BLOOM case
                    padded_past_keys[
                        start_index:end_index, :, :, -past_seq_len:
                    ] = past_keys[:, :, :, -past_seq_len:]
                del past_keys

                start_index = end_index

            padded_past_values = first_past_kvs[j][1].new_zeros(
                padded_past_values_shape
            )
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[
                    start_index:end_index, :, -past_seq_len:, :
                ] = past_values[:, :, -past_seq_len:, :]
                del past_values

                # Update values
                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class CausalLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize == "bitsandbytes",
        )
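        # Some models ship without a pad token; fall back to the EOS token so
        # that left padding is still possible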
        tokenizer.pad_token_id = (
            model.config.pad_token_id
            if model.config.pad_token_id is not None
            else model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), logits[-1:, :]
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids[:, 0], prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :, 0]
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1:
                    # Keep only the prefill tokens (drop the freshly generated one) and use NaN as the logprob of the first prompt token
                    prefill_logprobs = [float("nan")] + torch.log_softmax(
                        logits, -1
                    ).gather(1, all_input_ids[1:]).squeeze(1)[
                        -new_input_length:-1
                    ].tolist()
                    prefill_token_ids = all_input_ids[-new_input_length:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id_squeezed,
                    next_token_logprob,
                    next_token_text,
                    next_token_id_squeezed.item() in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            # Update values
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        return generations, batch
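

# Usage sketch (illustrative, not part of the upstream module): a minimal
# serving loop, assuming a `generate_pb2.Batch` protobuf (`pb`) has already
# been built by the router and that the `Model` base class exposes `tokenizer`
# and `device` attributes; the model id below is only an example.
#
#     model = CausalLM("bigscience/bloom-560m")
#     batch = CausalLMBatch.from_pb(pb, model.tokenizer, model.device)
#     generations = []
#     while batch is not None:
#         step_generations, batch = model.generate_token(batch)
#         generations.extend(step_generations)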