# causal_lm.py
import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class CausalLMBatch(Batch):
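    """State for a batch of requests served by a `CausalLM` model.

    Holds the padded input tensors, the per-request bookkeeping lists and the
    cached `past_key_values` so decoding can resume incrementally between
    `generate_token` calls.
    """
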
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Past metadata
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
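        """Serialize this batch's metadata back to its protobuf representation."""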
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
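        """Build a `CausalLMBatch` from a protobuf batch.

        Tokenizes every request in one call and pre-allocates the attention mask
        with enough room on the right for the maximum number of new tokens any
        request in the batch may still generate.
        """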
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(r.inputs)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
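        # Resulting row layout, assuming the tokenizer pads on the left as
        # `CausalLM` configures it (illustrative values: 3 prompt tokens,
        # max_input_length=5, padding_right_offset=2):
        #   [0, 0, 1, 1, 1, 0, 0]
        # i.e. left padding, the prompt, then zeros reserved for the tokens this
        # request may generate in future decode steps.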

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
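        """Drop finished requests, keeping only the ones listed in `requests`.

        Tensors (input_ids, attention_mask, past_key_values) are sliced in place
        to the surviving rows and to the shortest usable sequence length so that
        memory can be reclaimed incrementally.
        """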
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []

        new_padding_right_offset = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            new_padding_right_offset = max(
                new_padding_right_offset,
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
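        # Keep only the surviving rows of the attention mask, drop columns to the
        # left of the longest remaining input, and trim the right padding down to
        # the new maximum number of tokens still to be generated.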
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length):
            (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.offsets = offsets
        self.token_offsets = token_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
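        """Merge several prefilled batches into one padded batch.

        All tensors are re-padded to the largest `max_input_length` among the
        batches; past key values are copied layer by layer so the source tensors
        can be freed as soon as they have been consumed.
        """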
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

            start_index = end_index

        first_past_kvs = batches[0].past_key_values
        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[
                        start_index:end_index, :, -past_seq_len:, :
                    ] = past_keys[:, :, -past_seq_len:, :]
                else:
                    # BLOOM case
                    padded_past_keys[
                        start_index:end_index, :, :, -past_seq_len:
                    ] = past_keys[:, :, :, -past_seq_len:]
                del past_keys

                start_index = end_index

            padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[
                    start_index:end_index, :, -past_seq_len:, :
                ] = past_values[:, :, -past_seq_len:, :]
                del past_values

                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

    def __len__(self):
        return len(self.requests)


class CausalLM(Model):
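    """Generic causal language model served through `AutoModelForCausalLM`."""
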
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
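        """Load the tokenizer and model weights for `model_id`.

        Uses CUDA with bfloat16 (or float32 when bf16 is unsupported) when a GPU
        is available, otherwise float32 on CPU; 8-bit quantization is only
        supported on GPU.
        """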
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
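        """Decode a sequence of generated token ids into text."""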
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional[List[Tuple]] = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
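        """Run a single forward pass and return the logits and the KV cache."""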
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
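        """Run one generation step over the whole batch.

        Returns a `Generation` for every request, together with the updated batch,
        or `None` instead of a batch when every request has finished.
        """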
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), logits
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids[:, 0], offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                stopped = False

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Remove generated token to only have prefill and add nan for first prompt token
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids[1:]
                ).squeeze(1)[-new_input_length:-1].tolist()
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

            # Update values
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        return generations, batch
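

# Minimal usage sketch (illustrative only): it assumes a populated
# `generate_pb2.Batch` protobuf `pb` coming from the router and that the `Model`
# base class exposes the `tokenizer` and `device` attributes used below.
#
#   model = CausalLM("gpt2")
#   batch = CausalLMBatch.from_pb(pb, model.tokenizer, model.device)
#   generations, batch = model.generate_token(batch)  # prefill step
#   while batch is not None:
#       generations, batch = model.generate_token(batch)  # decode steps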