"vscode:/vscode.git/clone" did not exist on "05f656c01fb6ce5e9bddde59323f95887ee9a13b"
causal_lm.py 18.8 KB
Newer Older
1
2
import torch

3
from dataclasses import dataclass
4
from opentelemetry import trace
5
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
OlivierDehaene's avatar
OlivierDehaene committed
6
from typing import Optional, Tuple, List, Type
7

8
9
10
11
12
13
14
15
16
from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
17

18
19
tracer = trace.get_tracer(__name__)

20
21

@dataclass
class CausalLMBatch(Batch):
    """Batch state for left-padded causal language model generation.

    A batch is created from a protobuf request batch by ``from_pb`` (no past
    key/values yet). It is then advanced one token per step, and once
    prefilled it can be merged with other prefilled batches by ``concatenate``.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    # None until the batch has gone through its first (prefill) forward pass
    past_key_values: Optional[List[Tuple]]

    # All tokens (prompt + generated), one [seq_len, 1] tensor per request
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_sequence_length: int
    # Number of attention-mask columns reserved on the right for tokens that
    # have not been generated yet
    padding_right_offset: int

    # Past metadata
    # False when keys are stored as [..., head_dim, seq_len] (the BLOOM layout)
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        """Serialize the batch identity, requests and size to protobuf."""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a not-yet-prefilled batch from a protobuf batch.

        All prompts are tokenized together with ``padding=True``. The attention
        mask is allocated with ``padding_right_offset`` extra columns, which is
        the largest ``max_new_tokens`` in the batch, so every token the batch may
        generate has a pre-allocated column.

        Args:
            pb: incoming protobuf batch.
            tokenizer: tokenizer used to encode the prompts.
            device: device the tensors are moved to.

        Returns:
            A new ``CausalLMBatch`` with ``past_key_values=None``.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        # Parse batch
        max_sequence_length = 0
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_sequence_length = max(max_sequence_length, r.input_length)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
        ).to(device)

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_sequence_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        # NOTE(review): assumes the tokenizer pads to exactly max_sequence_length,
        # i.e. r.input_length matches the tokenized prompt length — confirm
        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]

        # Real tokens are numbered 0, 1, ... from the first non-padding token;
        # padding slots get a dummy position of 1.
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        # [batch_size, seq_len, 1]: iterating it yields one [seq_len, 1] tensor
        # per request
        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_sequence_length=max_sequence_length,
            padding_right_offset=padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several prefilled batches into a single batch.

        Tensors are re-padded on the left to the longest ``max_sequence_length``
        among ``batches`` and on the right to the largest ``padding_right_offset``.
        The returned batch reuses the id and ``keys_head_dim_last`` of
        ``batches[0]``.

        Raises:
            ValueError: if any batch has not been through a forward pass yet
                (its ``past_key_values`` is None).
        """
        # We only concatenate batches that did at least one step
        if any(batch.past_key_values is None for batch in batches):
            raise ValueError("only concatenate prefilled batches")

        # Used for padding
        total_batch_size = 0
        max_sequence_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += batch.size
            max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # input_ids is always of shape [batch_size, 1] after a step,
            # so it needs no padding
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_sequence_length + padding_right_offset),
                )

            # Slice the attention mask to drop this batch's extra left padding
            # from previous steps and its unused right-hand allocation. Explicit
            # end indices are used because `x:-0` would be an empty slice.
            left_offset = max_sequence_length - batch.max_sequence_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_sequence_length
                - batch.padding_right_offset
            )
            attention_mask[start_index:end_index, left_offset:max_sequence_length] = (
                batch.attention_mask[
                    :,
                    batch_left_offset : batch_left_offset + batch.max_sequence_length,
                ]
            )

            # position_ids is always of shape [batch_size, 1] after a step
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Number of past positions kept for this batch (its longest request)
            past_length = batch.max_sequence_length - 1

            for j, (past_keys, past_values) in enumerate(batch.past_key_values):
                # Shenanigans to get dimensions because BLOOM outputs a past with
                # a different shape:
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                # Viewing both as [batch_size, num_heads, ...] covers either layout.
                past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])

                _, num_heads, _, head_dim = past_values.shape

                padded_past_values_shape = (
                    total_batch_size,
                    num_heads,
                    max_sequence_length - 1,
                    head_dim,
                )

                if batch.keys_head_dim_last:
                    padded_past_keys_shape = padded_past_values_shape
                else:
                    # seq_length is last for BLOOM
                    padded_past_keys_shape = (
                        total_batch_size,
                        num_heads,
                        head_dim,
                        max_sequence_length - 1,
                    )

                # Allocate the merged tensors for this layer on first sight only
                if j == len(past_key_values):
                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                    past_key_values.append((padded_past_keys, padded_past_values))

                # Right-align the last `past_length` positions, dropping padding
                # left over from previous steps
                merged_keys, merged_values = past_key_values[j]
                if batch.keys_head_dim_last:
                    merged_keys[start_index:end_index, :, -past_length:, :] = (
                        past_keys[:, :, -past_length:, :]
                    )
                else:
                    merged_keys[start_index:end_index, :, :, -past_length:] = (
                        past_keys[:, :, :, -past_length:]
                    )

                merged_values[start_index:end_index, :, -past_length:, :] = (
                    past_values[:, :, -past_length:, :]
                )

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)

class CausalLM(Model):
    """Serves a Hugging Face causal LM one token step at a time on batches."""

    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        """Load the tokenizer and model for ``model_id``.

        On CUDA the model runs in bfloat16 when supported (float32 otherwise);
        on CPU it runs in float32.

        Args:
            model_id: Hugging Face model identifier or path.
            revision: optional model revision.
            quantize: load weights in 8-bit (CUDA only).

        Raises:
            ValueError: if ``quantize`` is requested without CUDA.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        # Left padding keeps every row's most recent token in the last column
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        # Fall back to the EOS token for padding when the model defines no pad token
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super().__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used with this model."""
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Decode token ids to text, skipping special tokens.

        Tokenization-space clean-up is disabled. The keyword is
        ``clean_up_tokenization_spaces``; a misspelled keyword would be
        silently ignored by the tokenizer.
        """
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the model with caching enabled; return (logits, past_key_values)."""
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        """Run one decoding step for every request in ``batch``.

        Returns:
            The per-request ``Generation`` results of this step, and the batch
            for the next step (finished requests removed), or None when every
            request in ``batch`` has stopped.
        """
        # Drop the right-hand columns reserved for tokens not generated yet.
        # An explicit end index is used because `[:, :-0]` would be empty.
        attention_mask = batch.attention_mask[
            :, : batch.attention_mask.shape[1] - batch.padding_right_offset
        ]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []

        # Metadata
        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Results
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), request_logits
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.decode_token(
                next_token_id_squeezed,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token_id)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Entry k of `token_logprobs` is the logprob of all_input_ids[k + 1].
                # all_input_ids is left-padded to the batch's longest prompt, so
                # this request's prompt is its last `input_length` tokens before
                # the generated one. Their logprobs are the `input_length - 1`
                # entries preceding the generated token's; the first prompt token
                # has none (nan). Slicing with `-new_input_length` instead would
                # pull in one extra padding-conditioned entry for padded prompts.
                token_logprobs = logprobs.gather(1, all_input_ids[1:]).squeeze(1)
                prompt_start = max(token_logprobs.shape[0] - input_length, 0)
                prefill_logprobs = [float("nan")] + token_logprobs[
                    prompt_start:-1
                ].tolist()
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of
        # the generations that finished from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items
            # that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_past_key_values = [
                [
                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_position_ids = batch.position_ids
            next_batch_past_key_values = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask as we added a new token to input_ids: mark the
        # first reserved right-hand column as attended
        next_batch_attention_mask[
            :, next_batch_attention_mask.shape[1] - batch.padding_right_offset
        ] = 1

        # Update position_ids
        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            position_ids=next_batch_position_ids,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            input_lengths=next_batch_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
            padding_right_offset=batch.padding_right_offset - 1,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generations, next_batch