causal_lm.py

import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class CausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Past metadata
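    # keys_head_dim_last is False for models such as BLOOM, whose cached keys are
    # laid out as [batch_size * num_heads, head_dim, seq_length] (see concatenate)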
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(r.inputs)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
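        # Note: each element of all_input_ids is a [padded_seq_len, 1] column tensor
        # (one per request) that grows by one row per decoding step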

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i
            keep_indices.append(idx)

            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criterias.append(self.stopping_criterias[idx])

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        attention_mask = self.attention_mask[keep_indices]
        position_ids = self.position_ids[keep_indices]
        # Force past to be of dim [self_size, num_heads, ...] for easy indexing
        past_key_values = [
            [t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
            for layer in self.past_key_values
        ]
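        # The view above folds any fused (batch * num_heads) leading dimension back
        # into a per-request axis so that keep_indices can select rows directly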

        return CausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=self.padding_right_offset,
            keys_head_dim_last=self.keys_head_dim_last,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]
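            # Both the destination and source slices above are batch.max_input_length
            # columns wide: the batch's live mask is re-aligned against the right edge
            # of the merged prompt region, and the new left padding stays zero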

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            for j, past in enumerate(batch.past_key_values):
                past_keys, past_values = past

                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                past_keys = past_keys.view(len(batch), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(batch), -1, *past_values.shape[-2:])

                _, num_heads, padded_sequence_length, head_dim = past_values.shape

                padded_past_values_shape = (
                    total_batch_size,
                    num_heads,
                    max_input_length - 1,
                    head_dim,
                )

                if batch.keys_head_dim_last:
                    padded_past_keys_shape = padded_past_values_shape
                else:
                    # seq_length is last for BLOOM
                    padded_past_keys_shape = (
                        total_batch_size,
                        num_heads,
                        head_dim,
                        max_input_length - 1,
                    )

                # This will run only once per layer
                if j == len(past_key_values):
                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                    past_key_values.append((padded_past_keys, padded_past_values))
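                    # (past_key_values starts empty and gains one (keys, values) pair
                    # per layer, so this branch only triggers while processing the
                    # first batch)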

                # We slice the past keys and values to remove the padding from previous batches
                if batch.keys_head_dim_last:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        -(batch.max_input_length - 1) :,
                        :,
                    ] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
                else:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        :,
                        -(batch.max_input_length - 1) :,
                    ] = past_keys[:, :, :, -(batch.max_input_length - 1) :]

                past_key_values[j][1][
                    start_index:end_index, :, -(batch.max_input_length - 1) :, :
                ] = past_values[:, :, -(batch.max_input_length - 1) :, :]

            start_index += len(batch)

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

    def __len__(self):
        return len(self.requests)


class CausalLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
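        # Note: load_in_8bit only takes effect when quantize=True; the CPU branch
        # above already rejects quantization, so 8-bit loading always runs on GPU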
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
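        # The batch mask is allocated for the full generation horizon; only the
        # positions filled so far are passed to the model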

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), logits
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids[:, 0], offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                stopped = False

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Remove generated token to only have prefill and add nan for first prompt token
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids[1:]
                ).squeeze(1)[-new_input_length:-1].tolist()
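                # logprobs.gather(...) picks, for each prompt position, the logprob of
                # the token that actually followed it; the first prompt token has no
                # context, hence the leading NaN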
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

            # Update values
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        return generations, batch
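
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). `pb` is a
# placeholder for a generate_pb2.Batch built by the router, and the model id is
# only an example:
#
#     model = CausalLM("bigscience/bloom-560m")
#     batch = CausalLMBatch.from_pb(pb, model.tokenizer, model.device)
#     generations, batch = model.generate_token(batch)  # prefill step
#     while batch is not None:                          # decode steps
#         generations, batch = model.generate_token(batch)
# ---------------------------------------------------------------------------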