causal_lm.py 17.4 KB
Newer Older
1
2
import torch

3
from dataclasses import dataclass
4
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
OlivierDehaene's avatar
OlivierDehaene committed
5
from typing import Optional, Tuple, List, Type
6
7

from text_generation.models import Model
8
from text_generation.models.types import GeneratedText, Batch
9
10
11
12
13
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria


@dataclass
class CausalLMBatch(Batch):
    """Decoding state for a batch of requests served by a causal LM.

    Tensors in a batch are right-aligned: ``concatenate`` copies every batch
    into the right-most columns of the merged tensors, dropping the extra
    left padding a shorter batch would otherwise carry.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    # Ids fed to the next forward pass (the whole padded prompt at prefill)
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    # Model cache; None until the batch has gone through one forward pass
    past_key_values: Optional[List[Tuple]]

    # All tokens
    # One [seq_len, 1] tensor of (padded) prompt + generated ids per request
    all_input_ids: List[torch.Tensor]
    # Per-request token logprobs; None until the request's first step
    all_logprobs: List[Optional[torch.Tensor]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_sequence_length: int

    # Past metadata
    # False when keys are laid out [..., head_dim, seq_length] (BLOOM style)
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        """Serialize the batch identity (id, requests, size) to protobuf."""
        return generate_pb2.Batch(
            id=self.batch_id, requests=self.requests, size=self.size
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a not-yet-prefilled batch from a protobuf ``Batch``.

        Args:
            pb: incoming batch message; one entry of ``pb.requests`` per request.
            tokenizer: tokenizer used to encode every request's ``inputs``.
            device: device the encoded tensors are moved to.

        Returns:
            A ``CausalLMBatch`` whose ``past_key_values`` is ``None``.
        """
        prompts = []
        lengths = []
        choosers = []
        criterias = []
        for request in pb.requests:
            prompts.append(request.inputs)
            lengths.append(request.input_length)
            choosers.append(NextTokenChooser.from_pb(request.parameters))
            criterias.append(
                StoppingCriteria.from_pb(request.stopping_parameters, tokenizer)
            )
        # Nothing has been scored yet for any request
        logprobs_per_request = [None] * len(lengths)

        # Pad to a multiple of 8 on GPU only
        multiple = 8 if device.type == "cuda" else None
        encoded = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=multiple,
            return_token_type_ids=False,
        ).to(device)
        ids = encoded["input_ids"]
        mask = encoded["attention_mask"]

        # Running count of real tokens gives 0-based positions; padded slots
        # get the placeholder position 1
        positions = mask.long().cumsum(-1) - 1
        positions.masked_fill_(mask == 0, 1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=ids,
            attention_mask=mask,
            position_ids=positions,
            past_key_values=None,
            all_input_ids=ids.unsqueeze(-1),
            all_logprobs=logprobs_per_request,
            input_lengths=lengths,
            next_token_choosers=choosers,
            stopping_criterias=criterias,
            size=pb.size,
            max_sequence_length=max(lengths),
        )

    @classmethod
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge already-prefilled batches into a single batch.

        Rows are stacked in the order of ``batches``. The attention mask is
        sized to the longest ``max_sequence_length`` and the cache to one
        position less. Every batch is copied into the right-most columns.

        Raises:
            ValueError: if a batch has no ``past_key_values`` yet.
        """
        total_batch_size = sum(b.size for b in batches)
        max_sequence_length = max(b.max_sequence_length for b in batches)

        # Per-request attributes are chained in batch order
        requests = [req for b in batches for req in b.requests]
        input_lengths = [length for b in batches for length in b.input_lengths]
        all_input_ids = [ids for b in batches for ids in b.all_input_ids]
        all_logprobs = [lp for b in batches for lp in b.all_logprobs]
        next_token_choosers = [
            chooser for b in batches for chooser in b.next_token_choosers
        ]
        stopping_criterias = [
            criteria for b in batches for criteria in b.stopping_criterias
        ]

        # Merged tensors are allocated from the first batch (inheriting its
        # dtype and device)
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        offset = 0
        for batch in batches:
            # Only batches that already ran one forward pass can be merged
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            rows = slice(offset, offset + batch.size)
            mask_length = batch.max_sequence_length
            past_length = batch.max_sequence_length - 1

            # input_ids is a single column here, so no padding is needed
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            input_ids[rows] = batch.input_ids

            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_sequence_length)
                )
            # Keep only the last `mask_length` columns: anything further left
            # is padding carried over from earlier steps
            attention_mask[rows, -mask_length:] = batch.attention_mask[
                :, -mask_length:
            ]

            # position_ids is a single column here as well
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[rows] = batch.position_ids

            for layer, (keys, values) in enumerate(batch.past_key_values):
                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                keys = keys.view(batch.size, -1, *keys.shape[-2:])
                values = values.view(batch.size, -1, *values.shape[-2:])
                _, num_heads, _, head_dim = values.shape

                # First time this layer is seen: allocate its merged buffers
                if layer == len(past_key_values):
                    values_shape = (
                        total_batch_size,
                        num_heads,
                        max_sequence_length - 1,
                        head_dim,
                    )
                    if batch.keys_head_dim_last:
                        keys_shape = values_shape
                    else:
                        # seq_length is the last axis of BLOOM keys
                        keys_shape = (
                            total_batch_size,
                            num_heads,
                            head_dim,
                            max_sequence_length - 1,
                        )
                    past_key_values.append(
                        (keys.new_zeros(keys_shape), values.new_zeros(values_shape))
                    )

                merged_keys, merged_values = past_key_values[layer]
                # Copy the last `past_length` cache positions, right-aligned
                if batch.keys_head_dim_last:
                    merged_keys[rows, :, -past_length:, :] = keys[
                        :, :, -past_length:, :
                    ]
                else:
                    merged_keys[rows, :, :, -past_length:] = keys[
                        :, :, :, -past_length:
                    ]
                merged_values[rows, :, -past_length:, :] = values[
                    :, :, -past_length:, :
                ]

            offset += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            all_logprobs=all_logprobs,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )
236
237
238


class CausalLM(Model):
    """``Model`` implementation backed by a ``transformers`` causal LM.

    Each ``generate_token`` call runs one forward pass over a
    ``CausalLMBatch`` and reuses the model's ``past_key_values`` cache
    between calls.
    """

    def __init__(self, model_name: str, quantize=False):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: name or path passed to ``from_pretrained``.
            quantize: forwarded as ``load_in_8bit``; rejected on CPU.

        Raises:
            ValueError: if ``quantize`` is requested and CUDA is unavailable.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        # Left padding keeps sequences right-aligned; the batching code slices
        # the right-most columns (e.g. position_ids[:, -1:])
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        # Fall back to EOS as the padding token when the config defines none
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super().__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class this model consumes."""
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Detokenize ``generated_ids`` without special tokens.

        ``generate_token`` passes a 1-D tensor slice here. The keyword is
        ``clean_up_tokenization_spaces``, the same one used for the per-token
        decode in ``generate_token``. The former ``cleanup_tokenization_spaces``
        spelling was not the tokenizer's keyword, so it never reached the
        tokenizer.
        """
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the model with caching enabled.

        Returns:
            ``(logits, past_key_values)`` taken from the model outputs.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
        """Run one decoding step for every request in ``batch``.

        Returns:
            ``(generated_texts, next_batch)``: a ``GeneratedText`` for every
            request whose stopping criteria fired at this step, and a batch
            holding the remaining requests, or ``None`` if every request
            finished.
        """
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
        )
        with context_manager():
            logits, past = self.forward(
                batch.input_ids,
                batch.attention_mask,
                batch.position_ids,
                batch.past_key_values,
            )

        # Indices (into `batch`) of the requests that keep generating
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []
        next_batch_all_logprobs = []

        # Metadata
        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Finished requests
        generated_texts: List[GeneratedText] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_logprobs,
        ) in enumerate(iterator):
            # Select next token
            tokens, logprobs = next_token_chooser(all_input_ids, request_logits)
            next_token = tokens[-1].view(1, 1)

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token])
            new_input_length = input_length + 1

            if all_logprobs is None:
                # logprobs of all prompt tokens (except the first one) and the generated token
                all_logprobs = logprobs.gather(1, all_input_ids[1:])
            else:
                # logprob of the generated token
                next_token_logprob = logprobs[-1, next_token]
                all_logprobs = torch.cat([all_logprobs, next_token_logprob])

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token.squeeze(),
                self.tokenizer.decode(
                    next_token.squeeze(), clean_up_tokenization_spaces=False
                ),
            )

            if stop:
                # Decode only the tokens generated for this request
                generated_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                output_text = request.inputs + generated_text

                # Slice with input_length to remove padding
                token_ids = all_input_ids[-new_input_length:]
                token_texts = self.tokenizer.batch_decode(token_ids)
                # Add NaN for the first prompt token
                token_logprobs = [float("nan")] + all_logprobs[
                    -input_length:
                ].squeeze(1).tolist()

                # Add to the list of finished generations with the original request
                generated_texts.append(
                    GeneratedText(
                        request=request,
                        output_text=output_text,
                        generated_tokens=stopping_criteria.current_tokens,
                        tokens=token_texts,
                        token_ids=token_ids.squeeze(1).tolist(),
                        logprobs=token_logprobs,
                        reason=reason,
                    )
                )
            else:
                # add to the next batch
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_all_logprobs.append(all_logprobs)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generated_texts, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if generated_texts:
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_past_key_values = [
                [
                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_position_ids = batch.position_ids
            next_batch_past_key_values = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        next_batch_attention_mask = torch.cat(
            [
                next_batch_attention_mask,
                next_batch_attention_mask.new_ones(next_batch_size, 1),
            ],
            dim=1,
        )

        # Update position_ids: the next step feeds one position past the last one
        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            position_ids=next_batch_position_ids,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            all_logprobs=next_batch_all_logprobs,
            input_lengths=next_batch_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generated_texts, next_batch