# causal_lm.py
import torch

3
from dataclasses import dataclass
4
from opentelemetry import trace
5
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
OlivierDehaene's avatar
OlivierDehaene committed
6
from typing import Optional, Tuple, List, Type
7
8

from text_generation.models import Model
9
from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
10
from text_generation.pb import generate_pb2
11
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
12

13
14
# Module-level OpenTelemetry tracer; the batch/model classes in this module
# use it to record "concatenate" and "generate_token" spans.
tracer = trace.get_tracer(__name__)

15
16

@dataclass
class CausalLMBatch(Batch):
    """Batch state for decoder-only (causal) language models.

    Sequences are left-padded. ``attention_mask`` is allocated once with
    ``max_sequence_length + padding_right_offset`` columns so that each
    generation step only has to switch on one more column on the right
    instead of reallocating the mask.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    # None until the first (prefill) forward pass has been run
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_sequence_length: int
    # Number of still-unused columns allocated on the right of attention_mask
    padding_right_offset: int

    # Past metadata
    # False for models (e.g. BLOOM) whose past keys put seq_length last
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        """Serialize the batch id, requests and size to a protobuf batch."""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a fresh, not yet prefilled, batch from a protobuf batch.

        Args:
            pb: incoming batch of requests.
            tokenizer: tokenizer used to encode the request inputs. It is
                expected to pad on the left (``CausalLM`` creates it with
                ``padding_side="left"``).
            device: device the tensors are moved to.

        Returns:
            A ``CausalLMBatch`` whose ``past_key_values`` is ``None``.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        # Parse batch
        max_sequence_length = 0
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_sequence_length = max(max_sequence_length, r.input_length)
            # Reserve room on the right for the longest requested generation
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
        ).to(device)

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_sequence_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask.
        # NOTE(review): this assumes the tokenizer's padded length equals the
        # largest ``r.input_length``; confirm the router computes
        # input_length with the same tokenizer.
        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]

        # Positions count only real (non-pad) tokens; pad positions get a
        # dummy value of 1
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        # Iterating over this yields one [seq_len, 1] tensor per request
        all_input_ids = input_ids.unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_sequence_length=max_sequence_length,
            padding_right_offset=padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several prefilled batches into a single batch.

        Attention masks and past key/values are re-padded on the left to the
        largest ``max_sequence_length`` and on the right to the largest
        ``padding_right_offset`` found among ``batches``.

        Args:
            batches: batches that have each gone through at least one
                forward pass.

        Returns:
            A new batch that keeps the id and ``keys_head_dim_last`` of
            ``batches[0]``.

        Raises:
            ValueError: if a batch was never prefilled (its
                ``past_key_values`` is ``None``).
        """
        # Used for padding
        total_batch_size = 0
        max_sequence_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += batch.size
            max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_sequence_length + padding_right_offset),
                )

            # We need to slice the attention mask to remove padding from
            # previous steps and to remove unused allocated space.
            # Explicit end indices are used rather than ``-padding_right_offset``
            # because ``x[a:-0]`` is an empty slice when the offset is 0.
            left_offset = max_sequence_length - batch.max_sequence_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_sequence_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:max_sequence_length,
            ] = batch.attention_mask[
                :,
                batch_left_offset : batch_left_offset + batch.max_sequence_length,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            for j, past in enumerate(batch.past_key_values):
                past_keys, past_values = past

                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])

                _, num_heads, padded_sequence_length, head_dim = past_values.shape

                padded_past_values_shape = (
                    total_batch_size,
                    num_heads,
                    max_sequence_length - 1,
                    head_dim,
                )

                if batch.keys_head_dim_last:
                    padded_past_keys_shape = padded_past_values_shape
                else:
                    # seq_length is last for BLOOM
                    padded_past_keys_shape = (
                        total_batch_size,
                        num_heads,
                        head_dim,
                        max_sequence_length - 1,
                    )

                # This will run only once per layer
                if j == len(past_key_values):
                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                    past_key_values.append((padded_past_keys, padded_past_values))

                # We slice the past keys and values to remove the padding from previous batches
                if batch.keys_head_dim_last:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        -(batch.max_sequence_length - 1) :,
                        :,
                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                else:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        :,
                        -(batch.max_sequence_length - 1) :,
                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]

                past_key_values[j][1][
                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

    def __len__(self):
        """Return the number of requests in the batch."""
        return len(self.requests)

270
271

class CausalLM(Model):
    """Text-generation model backed by a Hugging Face causal language model."""

    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        """Load the tokenizer and model weights.

        Args:
            model_id: model id (or path) passed to ``from_pretrained``.
            revision: optional model revision.
            quantize: load the weights with ``load_in_8bit``. Only allowed
                when CUDA is available.

        Raises:
            ValueError: if ``quantize`` is requested without CUDA.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        # Fall back to EOS for padding when the model defines no pad token
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super().__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used by this model."""
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Decode generated token ids to text, dropping special tokens.

        The kwarg was previously misspelled ``cleanup_tokenization_spaces``,
        which the tokenizer silently ignored. As a result, full outputs were
        space-cleaned while per-token texts were not.
        """
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the model with caching enabled.

        Returns:
            The model logits and the updated past key/values.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        """Run one forward step and choose the next token for every request.

        Args:
            batch: the current batch. For a prefill step its
                ``past_key_values`` is ``None``.

        Returns:
            The per-request ``Generation`` results, and the batch to feed to
            the next step. The batch is ``None`` when every request has
            finished.
        """
        # Slice the attention mask to the columns already in use. An explicit
        # end index is used because ``[:, :-0]`` would be empty when
        # ``padding_right_offset`` is 0.
        mask_width = batch.attention_mask.shape[1]
        attention_mask = batch.attention_mask[
            :, : mask_width - batch.padding_right_offset
        ]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []

        # Metadata
        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Results
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), request_logits
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.tokenizer.decode(
                next_token_id_squeezed,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=False,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token_id)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Gathering with all_input_ids[1:] scores token k + 1 at row k,
                # assuming row k of ``logprobs`` is the distribution predicted
                # at position k (implied by this gather). The prompt is the last
                # ``input_length`` tokens before the generated one. Keep the
                # scores of prompt tokens 2..input_length and add nan for the
                # first prompt token. Starting at ``-new_input_length`` would be
                # off by one for left-padded (shorter) requests.
                token_logprobs = logprobs.gather(1, all_input_ids[1:]).squeeze(1)
                prefill_logprobs = [float("nan")] + token_logprobs[
                    token_logprobs.shape[0] - input_length : -1
                ].tolist()
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_past_key_values = [
                [
                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_position_ids = batch.position_ids
            next_batch_past_key_values = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask as we added a new token to input_ids: switch on
        # the first still-unused column on the right. This is equivalent to
        # ``-padding_right_offset`` for positive offsets. With no room left it
        # raises instead of silently overwriting column 0.
        next_batch_attention_mask[:, mask_width - batch.padding_right_offset] = 1

        # Update position_ids
        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            position_ids=next_batch_position_ids,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            input_lengths=next_batch_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
            padding_right_offset=batch.padding_right_offset - 1,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generations, next_batch