from dataclasses import dataclass
from typing import List, Optional, Tuple, Type

import torch
from opentelemetry import trace
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    GeneratedText,
    Generation,
    PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, Sampling, StoppingCriteria

# OpenTelemetry tracer used to record spans for this module.
tracer = trace.get_tracer(__name__)


@dataclass
class CausalLMBatch(Batch):
    """Batch state for causal (decoder-only) language-model generation.

    Prompts are padded on the left by the tokenizer (``CausalLM`` configures
    ``padding_side="left"``). ``attention_mask`` has at least
    ``max_input_length + padding_right_offset`` columns (exactly that many when
    built by ``from_pb`` or ``concatenate``). Its rightmost
    ``padding_right_offset`` columns are zero-filled and reserved for tokens that
    have not been generated yet.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    # None until the first (prefill) forward pass has run
    past_key_values: Optional[List[Tuple]]

    # All tokens (prompt + generated) per request
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    # Per-request state threaded through `decode_token` between steps
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_input_length: int
    # Number of reserved (not yet used) columns at the right of `attention_mask`
    padding_right_offset: int

    # Past metadata
    # False for models such as BLOOM whose past keys are laid out
    # [batch * heads, head_dim, seq_len] rather than [..., seq_len, head_dim]
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        """Serialize the batch metadata (id, requests, size) to protobuf."""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a not-yet-prefilled batch from a protobuf batch.

        Args:
            pb: Incoming batch of requests.
            tokenizer: Tokenizer used to encode and truncate the prompts.
            device: Device on which the tokenized tensors are placed.

        Returns:
            A batch with ``past_key_values=None``, ready for its prefill pass.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        offsets = []
        token_offsets = []

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            # Reserve enough right-hand mask columns for the longest generation
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        # Plain Python int: used for tensor shapes, slicing and batch metadata
        max_input_length = int(input_lengths.max())

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        # Positions count only non-padding tokens; padded slots get a dummy value of 1
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths.tolist(),
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several already-prefilled batches into a single batch.

        Attention masks and past key/values are re-padded on the left to the
        largest ``max_input_length``. The attention mask is widened to the
        largest ``padding_right_offset``.

        Raises:
            ValueError: if a batch has no past key values (not prefilled yet).
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += batch.size
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            # Slice the attention mask to remove padding from previous steps and
            # unused allocated space. End indices are explicit because a
            # `-offset` slice end would select nothing when the offset is 0.
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            batch_right_end = batch.attention_mask.shape[1] - batch.padding_right_offset
            attention_mask[
                start_index:end_index,
                left_offset:max_input_length,
            ] = batch.attention_mask[
                :,
                batch_left_offset:batch_right_end,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Only the last `batch.max_input_length - 1` past positions are copied;
            # anything before them is left padding from previous steps
            past_seq_len = batch.max_input_length - 1
            for j, (past_keys, past_values) in enumerate(batch.past_key_values):
                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])

                _, num_heads, _, head_dim = past_values.shape

                # This will run only once per layer (while copying the first batch)
                if j == len(past_key_values):
                    padded_past_values_shape = (
                        total_batch_size,
                        num_heads,
                        max_input_length - 1,
                        head_dim,
                    )
                    if batch.keys_head_dim_last:
                        padded_past_keys_shape = padded_past_values_shape
                    else:
                        # seq_length is last for BLOOM
                        padded_past_keys_shape = (
                            total_batch_size,
                            num_heads,
                            head_dim,
                            max_input_length - 1,
                        )
                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                    past_key_values.append((padded_past_keys, padded_past_values))

                # We slice the past keys and values to remove the padding from previous batches
                if batch.keys_head_dim_last:
                    past_key_values[j][0][
                        start_index:end_index, :, -past_seq_len:, :
                    ] = past_keys[:, :, -past_seq_len:, :]
                else:
                    past_key_values[j][0][
                        start_index:end_index, :, :, -past_seq_len:
                    ] = past_keys[:, :, :, -past_seq_len:]

                past_key_values[j][1][
                    start_index:end_index, :, -past_seq_len:, :
                ] = past_values[:, :, -past_seq_len:, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class CausalLM(Model):
    """Hugging Face causal-LM wrapper that advances ``CausalLMBatch`` batches.

    Each call to ``generate_token`` runs one forward pass and produces one
    token per request.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        """Load the tokenizer and model weights.

        Args:
            model_id: Model id or path passed to ``from_pretrained``.
            revision: Optional model revision passed to ``from_pretrained``.
            quantize: Load weights with ``load_in_8bit``; only allowed with CUDA.
            decode_buffer: Forwarded to ``Model.__init__``.

        Raises:
            ValueError: if ``quantize`` is requested while CUDA is unavailable.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        # Pad and truncate on the left so every prompt ends in the same column
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        # Fall back to the EOS token for padding when the config defines no pad token
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class this model consumes."""
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Decode generated token ids into text.

        Special tokens are skipped. Tokenization-space cleanup is disabled to
        match the ``batch_decode`` call used for prefill tokens.
        """
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the model with caching enabled.

        Returns:
            The output logits and the updated past key/values.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        """Run one decoding step for every request in ``batch``.

        Returns:
            A ``Generation`` per request, and the batch for the next step
            (``None`` once every request has stopped).
        """
        # Feed only the columns already holding tokens; the last
        # `padding_right_offset` columns are reserved for future tokens. The end
        # index is explicit because `[:, :-0]` would select nothing.
        mask_end = batch.attention_mask.shape[1] - batch.padding_right_offset
        attention_mask = batch.attention_mask[:, :mask_end]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_offsets = []
        next_batch_token_offsets = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []

        # Metadata
        next_batch_size = 0
        next_batch_max_input_length = 0

        # Results
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), request_logits
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids[:, 0], offset, token_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token_id)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_offsets.append(offset)
                next_batch_token_offsets.append(token_offset)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, new_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Remove generated token to only have prefill and add nan for first prompt token
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids[1:]
                ).squeeze(1)[-new_input_length:-1].tolist()
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_past_key_values = [
                [
                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_position_ids = batch.position_ids
            next_batch_past_key_values = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask as we added a new token to input_ids.
        # `mask_end` is the first reserved column (same as `-padding_right_offset`
        # for a positive offset); the mask keeps the original width.
        next_batch_attention_mask[:, mask_end] = 1

        # Update position_ids
        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            position_ids=next_batch_position_ids,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            input_lengths=next_batch_input_lengths,
            offsets=next_batch_offsets,
            token_offsets=next_batch_token_offsets,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            padding_right_offset=batch.padding_right_offset - 1,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generations, next_batch