causal_lm.py 16.2 KB
Newer Older
1
2
import torch

3
from dataclasses import dataclass
4
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
OlivierDehaene's avatar
OlivierDehaene committed
5
from typing import Optional, Tuple, List, Type
6
7

from text_generation.models import Model
8
from text_generation.models.types import GeneratedText, Batch
9
10
11
12
13
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria


@dataclass
class CausalLMBatch(Batch):
    """Batch state for decoder-only (causal) language models.

    Holds the tensors fed to the model on the next forward pass together with
    the per-request bookkeeping (token history, logprobs, sampling and stopping
    helpers) needed to turn model outputs into generated text.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    # None until the batch has gone through its first forward pass
    past_key_values: Optional[List[Tuple]]

    # All tokens
    # One tensor per request holding its (possibly padded) prompt ids followed
    # by every generated id, each of shape [num_tokens, 1]
    all_input_ids: List[torch.Tensor]
    # Per request: logprobs of all_input_ids[1:], or None before the first
    # forward pass has computed them
    all_logprobs: List[Optional[torch.Tensor]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_sequence_length: int

    # Past metadata
    # False for models (e.g. BLOOM) whose past keys are laid out with the
    # sequence dimension last instead of the head dimension last
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        """Serialize the batch metadata (id, requests, size) to protobuf."""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
    ) -> "CausalLMBatch":
        """Build a fresh, not yet prefilled, batch from a protobuf batch.

        Args:
            pb: incoming batch of requests.
            tokenizer: tokenizer used to encode the prompts; the padding side
                is whatever the tokenizer was configured with.
            device: device the encoded tensors are moved to.

        Returns:
            A batch whose ``past_key_values`` is None, ready for its first
            forward pass.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []
        all_logprobs = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
            stopping_criterias.append(
                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
            )
            # Logprobs only become available after the first forward pass
            all_logprobs.append(None)

        # On GPU the padded length is rounded up to a multiple of 8
        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=pad_to_multiple_of,
        ).to(device)
        # Split into one [seq_len, 1] tensor per request so the field really is
        # a list, like the batches built by concatenate() and generate_token()
        all_input_ids = list(tokenized_inputs["input_ids"].unsqueeze(-1))

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            past_key_values=None,
            all_input_ids=all_input_ids,
            all_logprobs=all_logprobs,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_sequence_length=max(input_lengths),
        )

    @classmethod
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several prefilled batches into a single batch.

        The attention mask and past key/value tensors of every batch are copied
        into the right-most columns of zero-initialised tensors sized for the
        largest ``max_sequence_length`` among ``batches``.

        Args:
            batches: batches that have each gone through at least one forward
                pass. The first batch supplies the merged ``batch_id`` and
                ``keys_head_dim_last``.

        Returns:
            The merged batch.

        Raises:
            ValueError: if a batch has no ``past_key_values`` yet.
        """
        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_sequence_length = max(batch.max_sequence_length for batch in batches)

        # Batch attributes
        requests = []
        input_lengths = []
        all_input_ids = []
        all_logprobs = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            all_logprobs.extend(batch.all_logprobs)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # input_ids is always of shape [batch_size, 1]: no padding needed
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            input_ids[start_index:end_index] = batch.input_ids

            # Zero-initialised so every column not copied below stays masked
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_sequence_length),
                )

            # Only copy the right-most columns: anything further left is
            # padding accumulated during previous steps of this batch
            attention_mask[
                start_index:end_index, -batch.max_sequence_length :
            ] = batch.attention_mask[:, -batch.max_sequence_length :]

            # The past is sized one position shorter than the attention mask
            past_length = batch.max_sequence_length - 1

            for j, (past_keys, past_values) in enumerate(batch.past_key_values):
                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])

                # This will run only once per layer (while copying the first batch)
                if j == len(past_key_values):
                    _, num_heads, _, head_dim = past_values.shape
                    padded_past_values_shape = (
                        total_batch_size,
                        num_heads,
                        max_sequence_length - 1,
                        head_dim,
                    )
                    if batch.keys_head_dim_last:
                        padded_past_keys_shape = padded_past_values_shape
                    else:
                        # seq_length is last for BLOOM
                        padded_past_keys_shape = (
                            total_batch_size,
                            num_heads,
                            head_dim,
                            max_sequence_length - 1,
                        )
                    past_key_values.append(
                        (
                            past_keys.new_zeros(padded_past_keys_shape),
                            past_values.new_zeros(padded_past_values_shape),
                        )
                    )

                padded_past_keys, padded_past_values = past_key_values[j]

                # We slice the past keys and values to remove the padding from previous batches
                if batch.keys_head_dim_last:
                    padded_past_keys[
                        start_index:end_index, :, -past_length:, :
                    ] = past_keys[:, :, -past_length:, :]
                else:
                    padded_past_keys[
                        start_index:end_index, :, :, -past_length:
                    ] = past_keys[:, :, :, -past_length:]

                padded_past_values[
                    start_index:end_index, :, -past_length:, :
                ] = past_values[:, :, -past_length:, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            all_logprobs=all_logprobs,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )
220
221
222


class CausalLM(Model):
    """Text-generation model wrapping a Hugging Face causal language model."""

    def __init__(self, model_name: str, quantize=False):
        """Load ``model_name`` and its tokenizer.

        Args:
            model_name: identifier passed to ``from_pretrained``.
            quantize: forwarded as ``load_in_8bit``; only allowed when CUDA is
                available.

        Raises:
            ValueError: if ``quantize`` is requested without CUDA.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        # Left padding keeps the most recent token of every request in the
        # right-most column; the batch slicing logic relies on that layout
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        # Fall back to the EOS token for padding when the model defines no pad token
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class this model consumes."""
        return CausalLMBatch

    def forward(
        self, input_ids, attention_mask, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run one forward pass with the key/value cache enabled.

        Returns:
            A tuple ``(logits, past_key_values)`` taken from the model outputs.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
        """Run one decoding step for every request in ``batch``.

        Args:
            batch: batch to advance. When its ``past_key_values`` is None the
                whole padded prompt goes through the model.

        Returns:
            A tuple ``(generated_texts, next_batch)``: the requests that
            finished at this step, and a batch with the remaining requests, or
            None when every request finished.
        """
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
        )
        with context_manager():
            logits, past = self.forward(
                batch.input_ids, batch.attention_mask, batch.past_key_values
            )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []
        next_batch_all_logprobs = []

        # Metadata
        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Finished requests
        generated_texts: List[GeneratedText] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_logprobs,
        ) in enumerate(iterator):
            # Select next token
            tokens, logprobs = next_token_chooser(all_input_ids, request_logits)
            next_token = tokens[-1].view(1, 1)

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token])
            new_input_length = input_length + 1

            if all_logprobs is None:
                # logprobs of all prompt tokens (except the first one) and the generated token
                all_logprobs = logprobs.gather(1, all_input_ids[1:])
            else:
                # logprob of the generated token
                next_token_logprob = logprobs[-1, next_token]
                all_logprobs = torch.cat([all_logprobs, next_token_logprob])

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token.squeeze(),
                self.tokenizer.decode(
                    next_token.squeeze(), clean_up_tokenization_spaces=False
                ),
            )
            if stop:
                # Decode all tokens, keeping the tokenizer's raw spacing just
                # like the stopping-criteria decode above
                output_text = self.tokenizer.decode(
                    all_input_ids.squeeze(-1),
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )
                # Slice with new_input_length to remove padding
                token_ids = all_input_ids[-new_input_length:]
                tokens = self.tokenizer.batch_decode(token_ids)
                # all_logprobs[k] is the logprob of all_input_ids[k + 1], so the
                # last `input_length` entries cover every non-padding token except
                # the first one. An explicit start index (instead of
                # `-input_length`) keeps the slice empty when input_length is 0.
                logprobs_start = all_logprobs.shape[0] - input_length
                # Add NaN for the first prompt token so logprobs lines up with tokens
                logprobs = [float("nan")] + all_logprobs[logprobs_start:].squeeze(
                    1
                ).tolist()

                # Add to the list of finished generations with the original request
                generated_texts.append(
                    GeneratedText(
                        request=request,
                        output_text=output_text,
                        generated_tokens=stopping_criteria.current_tokens,
                        tokens=tokens,
                        token_ids=token_ids.squeeze(1).tolist(),
                        logprobs=logprobs,
                        reason=reason,
                    )
                )
            # add to the next batch
            else:
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_all_logprobs.append(all_logprobs)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generated_texts, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if generated_texts:
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_past_key_values = [
                [
                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_past_key_values = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Append a column of ones so the token just added to input_ids is attended to
        next_batch_attention_mask = torch.cat(
            [
                next_batch_attention_mask,
                next_batch_attention_mask.new_ones(next_batch_size, 1),
            ],
            dim=1,
        )

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            all_logprobs=next_batch_all_logprobs,
            input_lengths=next_batch_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generated_texts, next_batch