import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class CausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Past metadata
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(r.inputs)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
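        # The mask is over-allocated by padding_right_offset extra columns (the largest
        # max_new_tokens in the batch) so each decode step can simply flip the next
        # column to 1 instead of concatenating a new column to the mask every step.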
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
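
        # With left padding, cumsum(-1) - 1 over the attention mask yields positions
        # 0..length-1 for the real tokens; padding positions get a dummy value of 1,
        # which is harmless since they are masked out.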
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
        )

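    # Keep only the given requests in this batch (typically because the other requests
    # have finished generating), re-indexing the per-request lists and slicing the
    # shared tensors accordingly.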
    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criterias.append(self.stopping_criterias[idx])

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        attention_mask = self.attention_mask[keep_indices]
        position_ids = self.position_ids[keep_indices]
        # Force past to be of dim [self_size, num_heads, ...] for easy indexing
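        # Some models (e.g. BLOOM) fuse batch and heads into one leading dimension
        # ([batch_size * num_heads, ...]); the view below folds that back out so the
        # kept requests can be selected with a single batch index.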
        past_key_values = [
            [t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
            for layer in self.past_key_values
        ]

        return CausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=self.padding_right_offset,
            keys_head_dim_last=self.keys_head_dim_last,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
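            # left_offset right-aligns this (left-padded) batch inside the wider merged
            # mask; batch_left_offset skips any unused left columns the batch's own mask
            # still carries (e.g. after its longest requests were filtered out).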
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            for j, past in enumerate(batch.past_key_values):
                past_keys, past_values = past

                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                past_keys = past_keys.view(len(batch), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(batch), -1, *past_values.shape[-2:])

                _, num_heads, padded_sequence_length, head_dim = past_values.shape

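                # The cache only holds max_input_length - 1 positions: the latest token
                # (currently sitting in input_ids) has not been run through the model
                # yet, so its key/value pair is not in the cache.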
                padded_past_values_shape = (
                    total_batch_size,
                    num_heads,
                    max_input_length - 1,
                    head_dim,
                )

                if batch.keys_head_dim_last:
                    padded_past_keys_shape = padded_past_values_shape
                else:
                    # seq_length is last for BLOOM
                    padded_past_keys_shape = (
                        total_batch_size,
                        num_heads,
                        head_dim,
                        max_input_length - 1,
                    )

                # This will run only once per layer
                if j == len(past_key_values):
                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                    past_key_values.append((padded_past_keys, padded_past_values))

                # We slice the past keys and values to remove the padding from previous batches
                if batch.keys_head_dim_last:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        -(batch.max_input_length - 1) :,
                        :,
                    ] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
                else:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        :,
                        -(batch.max_input_length - 1) :,
                    ] = past_keys[:, :, :, -(batch.max_input_length - 1) :]

                past_key_values[j][1][
                    start_index:end_index, :, -(batch.max_input_length - 1) :, :
                ] = past_values[:, :, -(batch.max_input_length - 1) :, :]

            start_index += len(batch)

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

    def __len__(self):
        return len(self.requests)


class CausalLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()

        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), logits
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids[:, 0], offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                stopped = False

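            # On the prefill step we also return per-token logprobs for the prompt:
            # gather picks, at each position, the logprob assigned to the token that
            # actually came next; the first prompt token has no prediction, hence NaN.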
            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Remove generated token to only have prefill and add nan for first prompt token
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids[1:]
                ).squeeze(1)[-new_input_length:-1].tolist()
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

            # Update values
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        return generations, batch
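
# Rough usage sketch (illustrative only): the surrounding server normally drives these
# calls, and the `model.tokenizer` / `model.device` attributes are assumed to be set by
# the `Model` base class rather than guaranteed by this file.
#
#   model = CausalLM("gpt2")  # hypothetical model_id
#   batch = CausalLMBatch.from_pb(pb_batch, model.tokenizer, model.device)
#   generations, batch = model.generate_token(batch)      # prefill
#   while batch is not None:
#       generations, batch = model.generate_token(batch)  # one decode step per call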