# causal_lm.py (18.8 KB)
# Newer / Older
1
2
import torch

3
from dataclasses import dataclass
4
from opentelemetry import trace
5
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
OlivierDehaene's avatar
OlivierDehaene committed
6
from typing import Optional, Tuple, List, Type
7

8
9
10
11
12
13
14
15
16
from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
17

18
19
# Module-level OpenTelemetry tracer; the `start_as_current_span` decorators in
# this file use it to wrap methods in named spans.
tracer = trace.get_tracer(__name__)

20
21

@dataclass
class CausalLMBatch(Batch):
    """State carried from one decoding step to the next for a batch of requests.

    The attention mask is allocated once with room for every token that may
    still be generated (`padding_right_offset` trailing columns), so each step
    only flips one column to 1 instead of re-allocating the mask.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_input_length: int
    # Number of still-unused columns reserved at the right of `attention_mask`
    padding_right_offset: int

    # Past metadata
    # False when the cached keys have seq_length (not head_dim) as last dim (BLOOM)
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        """Return the protobuf description (id, requests, size) of this batch."""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a fresh (not yet prefilled) batch from its protobuf description.

        Args:
            pb: protobuf batch holding the requests to serve.
            tokenizer: tokenizer used to encode the request inputs and to build
                the stopping criteria.
            device: device the tokenized tensors are moved to.

        Returns:
            A batch with `past_key_values=None` and an attention mask that
            already reserves space for the largest `max_new_tokens` in the batch.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        tokenizer_mask = tokenized_inputs["attention_mask"]
        input_lengths = tokenizer_mask.sum(1)
        # Convert once to a Python int: it is used below as a shape, as a slice
        # bound and as batch metadata.
        max_input_length = input_lengths.max().item()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenizer_mask

        # Positions count the attended tokens only; masked slots get a dummy 1
        position_ids = tokenizer_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenizer_mask == 0, 1)
        all_input_ids = input_ids.unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths.tolist(),
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several prefilled batches into a single batch.

        Tensors are re-padded to the largest `max_input_length` and the largest
        `padding_right_offset` found among `batches`. The resulting batch takes
        its `batch_id` and `keys_head_dim_last` from the first batch.

        Raises:
            ValueError: if one of the batches has no `past_key_values`, i.e. it
                has not gone through at least one forward pass.
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += batch.size
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Number of cached positions to keep for this batch; the token
            # currently held in input_ids is not part of the past yet
            past_length = batch.max_input_length - 1

            for j, past in enumerate(batch.past_key_values):
                past_keys, past_values = past

                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])

                _, num_heads, _, head_dim = past_values.shape

                padded_past_values_shape = (
                    total_batch_size,
                    num_heads,
                    max_input_length - 1,
                    head_dim,
                )

                if batch.keys_head_dim_last:
                    padded_past_keys_shape = padded_past_values_shape
                else:
                    # seq_length is last for BLOOM
                    padded_past_keys_shape = (
                        total_batch_size,
                        num_heads,
                        head_dim,
                        max_input_length - 1,
                    )

                # This will run only once per layer
                if j == len(past_key_values):
                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                    past_key_values.append((padded_past_keys, padded_past_values))

                # We slice the past keys and values to remove the padding from previous batches
                if batch.keys_head_dim_last:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        -past_length:,
                        :,
                    ] = past_keys[:, :, -past_length:, :]
                else:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        :,
                        -past_length:,
                    ] = past_keys[:, :, :, -past_length:]

                past_key_values[j][1][
                    start_index:end_index, :, -past_length:, :
                ] = past_values[:, :, -past_length:, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

    def __len__(self) -> int:
        """Return the number of requests in the batch."""
        return len(self.requests)

278
279

class CausalLM(Model):
    """Generic causal language model served through `AutoModelForCausalLM`."""

    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        """Load the tokenizer and the model weights.

        Args:
            model_id: model identifier passed to `from_pretrained`.
            revision: optional revision passed to `from_pretrained`.
            quantize: forwarded as `load_in_8bit`; only allowed when CUDA is
                available.

        Raises:
            ValueError: if `quantize` is requested without CUDA.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        # Left padding keeps the last position of every row a real token
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        # Fall back to the EOS token when the model config defines no pad token
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super().__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used with this model."""
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Decode generated token ids to text, dropping special tokens.

        Tokenization-space cleanup is disabled, matching the `batch_decode`
        call used for prefill tokens in `generate_token`.
        """
        # The keyword is `clean_up_tokenization_spaces`; the former spelling
        # `cleanup_tokenization_spaces` is not the documented parameter name.
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run one model forward pass with caching enabled.

        Returns:
            The `logits` and `past_key_values` of the model output.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        """Generate one token for every request of `batch`.

        Returns:
            A `Generation` per request, plus the batch to use for the next step
            holding only the requests that did not stop (`None` when every
            request stopped).
        """
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []

        # Metadata
        next_batch_size = 0
        next_batch_max_input_length = 0

        # Results
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), request_logits
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.decode_token(
                next_token_id_squeezed,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token_id)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, new_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Remove generated token to only have prefill and add nan for first prompt token.
                # Entry k of the gathered vector scores the token at position k + 1
                # (assumes `logprobs` has one row per position of the padded input).
                # Slicing from -input_length keeps exactly the logprobs of prompt tokens
                # 2..n: a start of -new_input_length would also keep the first prompt
                # token's logprob whenever the row is left-padded, yielding one more
                # logprob than prefill tokens and shifting every value by one.
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids[1:]
                ).squeeze(1)[-input_length:-1].tolist()
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_past_key_values = [
                [
                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_position_ids = batch.position_ids
            next_batch_past_key_values = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask as we added a new token to input_ids
        next_batch_attention_mask[:, -batch.padding_right_offset] = 1

        # Update position_ids
        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            position_ids=next_batch_position_ids,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            input_lengths=next_batch_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            # One reserved mask column has just been consumed
            padding_right_offset=batch.padding_right_offset - 1,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generations, next_batch