causal_lm.py 17.5 KB
Newer Older
1
2
import torch

3
from dataclasses import dataclass
4
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
OlivierDehaene's avatar
OlivierDehaene committed
5
from typing import Optional, Tuple, List, Type
6
7

from text_generation.models import Model
8
from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
9
from text_generation.pb import generate_pb2
10
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
11
12
13


@dataclass
class CausalLMBatch(Batch):
    """Server-side state of a batch of causal-LM generation requests.

    Holds the tensors fed to the model on the next forward pass together with
    the per-request helpers used to pick the next token and decide when to stop.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    # None until a forward pass has produced a cache (see from_pb / concatenate)
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_sequence_length: int

    # Past metadata
    # False when cached keys put seq_length last instead of head_dim (BLOOM layout)
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        """Serialize the batch id, its requests and its size to protobuf."""
        return generate_pb2.Batch(
            id=self.batch_id, requests=self.requests, size=self.size
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a not-yet-prefilled batch from a protobuf batch.

        All prompts are tokenized together with padding and moved to `device`;
        `past_key_values` starts out as None.
        """
        texts = []
        lengths = []
        choosers = []
        criterias = []

        # One pass over the requests, building each request's helpers in turn
        for request in pb.requests:
            texts.append(request.inputs)
            lengths.append(request.input_length)
            choosers.append(NextTokenChooser.from_pb(request.parameters))
            criterias.append(
                StoppingCriteria.from_pb(request.stopping_parameters, tokenizer)
            )

        # Only pad to a multiple of 8 on CUDA devices
        multiple = None
        if device.type == "cuda":
            multiple = 8

        encoded = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=multiple,
            return_token_type_ids=False,
        ).to(device)
        ids = encoded["input_ids"]
        mask = encoded["attention_mask"]

        # Positions count only real tokens; padded slots get a placeholder of 1
        positions = mask.long().cumsum(-1) - 1
        positions.masked_fill_(mask == 0, 1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=ids,
            attention_mask=mask,
            position_ids=positions,
            past_key_values=None,
            all_input_ids=ids.unsqueeze(-1),
            input_lengths=lengths,
            next_token_choosers=choosers,
            stopping_criterias=criterias,
            size=pb.size,
            max_sequence_length=max(lengths),
        )

    @classmethod
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several prefilled batches into a single batch.

        Each batch's rows are written into tensors sized for the combined batch.
        The attention mask and the cached keys/values are zero-filled and every
        batch's last `max_sequence_length` (resp. `max_sequence_length - 1`)
        positions are copied into the right-most columns, dropping padding left
        over from earlier steps.

        Raises:
            ValueError: if any batch has no `past_key_values` yet.
        """
        total_size = sum(b.size for b in batches)
        longest = max(b.max_sequence_length for b in batches)

        requests = []
        input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # First row of the current batch inside the merged tensors
        offset = 0
        for batch in batches:
            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            rows = slice(offset, offset + batch.size)
            tail = batch.max_sequence_length
            past_tail = tail - 1

            # input_ids and position_ids are [batch_size, 1]: no padding needed
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_size, 1))
            input_ids[rows] = batch.input_ids

            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_size, longest),
                )
            attention_mask[rows, -tail:] = batch.attention_mask[:, -tail:]

            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_size, 1))
            position_ids[rows] = batch.position_ids

            for layer_idx, (keys, values) in enumerate(batch.past_key_values):
                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                keys = keys.view(batch.size, -1, *keys.shape[-2:])
                values = values.view(batch.size, -1, *values.shape[-2:])
                _, heads, _, head_dim = values.shape

                values_shape = (total_size, heads, longest - 1, head_dim)
                # seq_length is last for BLOOM keys
                keys_shape = (
                    values_shape
                    if batch.keys_head_dim_last
                    else (total_size, heads, head_dim, longest - 1)
                )

                # Merged buffers are allocated once per layer, from the first batch
                if layer_idx == len(past_key_values):
                    past_key_values.append(
                        (keys.new_zeros(keys_shape), values.new_zeros(values_shape))
                    )

                merged_keys, merged_values = past_key_values[layer_idx]
                if batch.keys_head_dim_last:
                    merged_keys[rows, :, -past_tail:, :] = keys[:, :, -past_tail:, :]
                else:
                    merged_keys[rows, :, :, -past_tail:] = keys[:, :, :, -past_tail:]
                merged_values[rows, :, -past_tail:, :] = values[:, :, -past_tail:, :]

            offset += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_size,
            max_sequence_length=longest,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class CausalLM(Model):
    """Serve generation requests with a Hugging Face ``AutoModelForCausalLM``.

    Each call to :meth:`generate_token` runs one forward pass and advances every
    request of a :class:`CausalLMBatch` by one token.
    """

    def __init__(self, model_name: str, quantize=False):
        """Load ``model_name`` and its tokenizer.

        Args:
            model_name: identifier passed to ``from_pretrained``.
            quantize: load weights with ``load_in_8bit``; only allowed on CUDA.

        Raises:
            ValueError: if ``quantize`` is requested and CUDA is not available.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        # Left padding keeps the most recent token in the last column of every
        # row; the batch code slices and appends at the right end.
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        # Fall back to EOS for padding when the config defines no pad token
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class this model consumes and produces."""
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Decode generated token ids to text, skipping special tokens.

        Tokenization spaces are not cleaned up, consistent with the per-token
        decoding in :meth:`generate_token`.
        """
        # NOTE: the keyword is `clean_up_tokenization_spaces`; a misspelled name
        # is silently swallowed by **kwargs and the default cleanup applies.
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run one forward pass with the KV cache enabled.

        Returns:
            The model's logits and its updated ``past_key_values``.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        """Run one decoding step for every request in ``batch``.

        Returns:
            One ``Generation`` per request, and the batch for the next step
            (without finished requests), or ``None`` if all requests stopped.
        """
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
        )
        with context_manager():
            logits, past = self.forward(
                batch.input_ids,
                batch.attention_mask,
                batch.position_ids,
                batch.past_key_values,
            )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []

        # Metadata
        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Results
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            tokens, logprobs = next_token_chooser(all_input_ids, request_logits)
            next_token_id = tokens[-1].view(1, 1)

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.tokenizer.decode(
                next_token_id_squeezed,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=False,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                generated_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                output_text = request.inputs + generated_text

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token_id)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Row k of `logprobs` scores the token at position k + 1, so the
                # gather yields one log-prob per position 1..end; the last one
                # belongs to the generated token. Inputs are left-padded, so the
                # prompt fills the last `input_length` input positions and its
                # tokens 2..input_length are the `input_length - 1` entries just
                # before that last one. Slicing with `new_input_length` would
                # also pick up the first prompt token scored against padding and
                # misalign logprobs with token ids. The first token gets NaN.
                position_logprobs = logprobs.gather(1, all_input_ids[1:]).squeeze(1)
                prefill_logprobs = [float("nan")] + position_logprobs[
                    -input_length:-1
                ].tolist()
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_past_key_values = [
                [
                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_position_ids = batch.position_ids
            next_batch_past_key_values = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        next_batch_attention_mask = torch.cat(
            [
                next_batch_attention_mask,
                next_batch_attention_mask.new_ones(next_batch_size, 1),
            ],
            dim=1,
        )

        # Update position_ids
        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            position_ids=next_batch_position_ids,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            input_lengths=next_batch_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generations, next_batch