# causal_lm.py
import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

# OpenTelemetry tracer; used below to open spans around "concatenate" and
# "generate_token".
tracer = trace.get_tracer(__name__)


@dataclass
class CausalLMBatch(Batch):
    """Batch state for a decoder-only (causal) language model.

    Holds the per-request helpers alongside the padded tensors fed to the model
    at every decoding step. ``attention_mask`` is allocated once with
    ``padding_right_offset`` extra columns on the right, so each decoding step
    only has to set one more column to 1 instead of reallocating.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens (prompt + generated), one tensor per request
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_input_length: int
    # Number of still-unused attention_mask columns on the right
    padding_right_offset: int

    # Past metadata
    # False for models whose past keys put seq_length last (BLOOM layout,
    # see `concatenate`); True when keys are laid out like values.
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        """Serialize the batch identity and requests back to protobuf."""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a batch from its protobuf description and tokenize the prompts.

        Args:
            pb: protobuf batch holding the requests.
            tokenizer: tokenizer used to encode all prompts together with
                ``padding=True``. The tensor layout below assumes left padding
                (``CausalLM`` creates its tokenizer with ``padding_side="left"``).
            device: device the tensors are moved to.

        Returns:
            A batch ready for its first (prefill) forward pass; its
            ``past_key_values`` is ``None``.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        # Parse batch; the right padding must fit the longest generation
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
        ).to(device)

        input_ids = tokenized_inputs["input_ids"]
        prompt_attention_mask = tokenized_inputs["attention_mask"]

        input_lengths = prompt_attention_mask.sum(1)
        max_input_length = int(input_lengths.max())

        # Allocate maximum attention_mask, then copy the tokenizer mask into it
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        attention_mask[:, :max_input_length] = prompt_attention_mask

        # Positions count real tokens only; padding slots get the placeholder 1
        # (they are masked out by attention_mask)
        position_ids = prompt_attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(prompt_attention_mask == 0, 1)

        all_input_ids = input_ids.unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths.tolist(),
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several already-prefilled batches into a single batch.

        Tensors are re-padded to the largest ``max_input_length`` and
        ``padding_right_offset`` among ``batches``. Attention-mask columns and
        past positions outside each batch's current window are dropped.

        Args:
            batches: batches to merge; each must have run at least one step.

        Returns:
            A new batch that reuses the id of ``batches[0]``.

        Raises:
            ValueError: if a batch was never prefilled (``past_key_values`` is
                ``None``).
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += batch.size
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # The newest token of every request is still in input_ids and has no
        # past entry yet, hence the "- 1"
        padded_past_length = max_input_length - 1

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # input_ids is always of shape [batch_size, 1]; no padding needed
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            # The batch's live window is the `batch.max_input_length` columns
            # just before its right padding; anything left of it is stale
            # padding from previous steps. Both sides use explicit end indices:
            # `:-offset` would select nothing when an offset is 0.
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index, left_offset:max_input_length
            ] = batch.attention_mask[
                :, batch_left_offset : batch_left_offset + batch.max_input_length
            ]

            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Past positions actually used by this batch, right-aligned in the
            # padded tensors. Explicit starts avoid `-0:` selecting everything.
            batch_past_length = batch.max_input_length - 1
            past_start = padded_past_length - batch_past_length

            for j, (past_keys, past_values) in enumerate(batch.past_key_values):
                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])

                _, num_heads, padded_sequence_length, head_dim = past_values.shape

                padded_past_values_shape = (
                    total_batch_size,
                    num_heads,
                    padded_past_length,
                    head_dim,
                )
                if batch.keys_head_dim_last:
                    padded_past_keys_shape = padded_past_values_shape
                else:
                    # seq_length is last for BLOOM
                    padded_past_keys_shape = (
                        total_batch_size,
                        num_heads,
                        head_dim,
                        padded_past_length,
                    )

                # Allocates the padded tensors once per layer (first batch only)
                if j == len(past_key_values):
                    past_key_values.append(
                        (
                            past_keys.new_zeros(padded_past_keys_shape),
                            past_values.new_zeros(padded_past_values_shape),
                        )
                    )

                padded_keys, padded_values = past_key_values[j]

                # Copy only the last `batch_past_length` positions, dropping
                # padding from previous steps
                if batch.keys_head_dim_last:
                    keys_seq_length = past_keys.shape[2]
                    padded_keys[start_index:end_index, :, past_start:, :] = past_keys[
                        :, :, keys_seq_length - batch_past_length :, :
                    ]
                else:
                    keys_seq_length = past_keys.shape[3]
                    padded_keys[start_index:end_index, :, :, past_start:] = past_keys[
                        :, :, :, keys_seq_length - batch_past_length :
                    ]

                padded_values[start_index:end_index, :, past_start:, :] = past_values[
                    :, :, padded_sequence_length - batch_past_length :, :
                ]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class CausalLM(Model):
    """Serve a Hugging Face causal language model one decoding step at a time."""

    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        """Load the tokenizer and the model.

        Args:
            model_id: model identifier passed to ``from_pretrained``.
            revision: optional model revision.
            quantize: load the weights with ``load_in_8bit``; CUDA only.

        Raises:
            ValueError: if ``quantize`` is requested and CUDA is unavailable.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        # Left padding keeps every prompt right-aligned, so the last column
        # always holds the newest token of each request
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        # Fall back to EOS when the model config defines no pad token
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super().__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used by this model."""
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Decode generated token ids to text, skipping special tokens.

        Tokenization-space cleanup is disabled, matching the ``batch_decode``
        call used for prefill tokens (the transformers keyword is
        ``clean_up_tokenization_spaces``).
        """
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the underlying model with the KV cache enabled.

        Returns:
            ``(logits, past_key_values)`` as returned by the model.
        """
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        """Advance every request in ``batch`` by one token.

        On the first call for a batch (``past_key_values`` is ``None``) the
        whole prompt is fed and prefill log-probs are reported; afterwards
        only the last generated token of each request is fed.

        Args:
            batch: the batch to advance.

        Returns:
            ``(generations, next_batch)``. ``generations`` holds one
            ``Generation`` per request in ``batch``. ``next_batch`` contains
            only the requests that did not stop, or is ``None`` when all of
            them finished.
        """
        # Columns in use; the remaining right padding is reserved for future
        # tokens. An explicit end index avoids `[:, :-0]` selecting nothing.
        used_columns = batch.attention_mask.shape[1] - batch.padding_right_offset
        attention_mask = batch.attention_mask[:, :used_columns]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Indices (into `batch`) of the requests that keep generating
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []

        # Metadata
        next_batch_size = 0
        next_batch_max_input_length = 0

        # Results
        generations: List[Generation] = []

        iterator = zip(
            batch.requests,
            batch.input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        for i, (
            request,
            input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), request_logits
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.decode_token(next_token_id_squeezed)

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id_squeezed, next_token_text)

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # A seed is only reported for sampling
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token_id)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, new_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # token_logprobs[t] is the log-prob, at position t, of the token
                # at position t + 1 (the last entry scores the token just
                # generated).
                token_logprobs = logprobs.gather(1, all_input_ids[1:]).squeeze(1)
                # The prompt occupies the last `input_length` positions (left
                # padding). Its first token gets NaN; the others are scored at
                # positions [len - input_length, len - 1). Starting one position
                # earlier would include a log-prob predicted from padding and
                # misalign every left-padded request by one.
                first_scored = token_logprobs.shape[0] - input_length
                prefill_logprobs = [float("nan")] + token_logprobs[
                    first_scored:-1
                ].tolist()
                # Remove generated token to only keep the prompt ids
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_past_key_values = [
                [
                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[k] for k in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[k] for k in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[k] for k in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_position_ids = batch.position_ids
            next_batch_past_key_values = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Attend to the token generated in this step: it occupies the first
        # right-padding column (same index as `-padding_right_offset` when the
        # offset is positive, but never wraps around to column 0)
        next_batch_attention_mask[:, used_columns] = 1

        # Update position_ids
        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            position_ids=next_batch_position_ids,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            input_lengths=next_batch_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            padding_right_offset=batch.padding_right_offset - 1,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generations, next_batch