causal_lm.py 15 KB
Newer Older
1
2
import torch

3
from dataclasses import dataclass
4
from transformers import AutoTokenizer, AutoModelForCausalLM
OlivierDehaene's avatar
OlivierDehaene committed
5
from typing import Optional, Tuple, List, Type
6
7

from text_generation.models import Model
8
9
10
11
12
13
14
15
16
from text_generation.models.types import GeneratedText
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria


@dataclass
class CausalLMBatch:
    """State of a batch of causal-LM generation requests between decoding steps.

    A batch is first built from its protobuf form with `from_pb` (no past yet),
    and several batches that already ran at least one decoding step can be
    merged into one with `concatenate`.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    # NOTE(review): `from_pb` stores a single tensor here (input_ids with a
    # trailing unit dimension) rather than a list; both are only ever iterated
    # per request in this file.
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_sequence_length: int

    # Past metadata
    # False when the past keys are laid out [..., head_dim, seq_length] (BLOOM)
    keys_head_dim_last: bool = True

    def to_pb(self):
        """Return the protobuf `Batch` message (id, requests, size) for this batch."""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
    ) -> "CausalLMBatch":
        """Build a fresh batch (no past key values) from its protobuf message.

        Args:
            pb: protobuf batch; each request provides `inputs`, `input_length`,
                `parameters` (temperature, top_k, top_p, do_sample) and
                `max_new_tokens`.
            tokenizer: tokenizer used to encode and pad the request inputs.
            device: device the tokenized tensors are moved to.

        Returns:
            A `CausalLMBatch` whose `max_sequence_length` is the largest
            `input_length` reported by the requests.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            next_token_choosers.append(
                NextTokenChooser(
                    temperature=r.parameters.temperature,
                    top_k=r.parameters.top_k,
                    top_p=r.parameters.top_p,
                    do_sample=r.parameters.do_sample,
                )
            )
            stopping_criterias.append(
                StoppingCriteria(
                    eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
                )
            )

        # Pad to a multiple of 8 on GPU only. The device type must be checked:
        # torch device strings are "cpu", "cuda", "cuda:0", ... and never
        # contain "gpu", so a substring test on "gpu" would never enable this.
        # `torch.device(...)` accepts both a device object and a device string.
        pad_to_multiple_of = 8 if torch.device(device).type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=pad_to_multiple_of,
        ).to(device)
        # One [seq_length, 1] slice per request when iterated
        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            past_key_values=None,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_sequence_length=max(input_lengths),
        )

    @classmethod
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge batches that already ran at least one decoding step.

        Attention masks and past key/values are copied into zero-initialised
        tensors sized for the longest sequence among all batches; each batch
        is written into the right-most columns, so shorter batches end up
        padded on the left.

        Args:
            batches: batches to merge, in order. The id and
                `keys_head_dim_last` of the first batch are used for the result.

        Returns:
            A single `CausalLMBatch` containing every request of `batches`.

        Raises:
            ValueError: if a batch has `input_ids` with more than one column,
                or has no past key values (it has not been through a forward
                pass yet).
        """
        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_sequence_length = max(batch.max_sequence_length for batch in batches)

        # Batch attributes
        requests = []
        input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.input_ids.shape[1] > 1:
                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
            # A one-token prompt passes the shape check above even before its
            # first forward pass; fail clearly instead of on `enumerate(None)`.
            if batch.past_key_values is None:
                raise ValueError(
                    "Batch past_key_values should not be None when concatenating"
                )

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = torch.empty(
                    (total_batch_size, 1),
                    dtype=batch.input_ids.dtype,
                    device=batch.input_ids.device,
                )
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = torch.zeros(
                    (total_batch_size, max_sequence_length),
                    dtype=batch.attention_mask.dtype,
                    device=batch.attention_mask.device,
                )

            # We need to slice the attention mask to remove padding from previous steps
            attention_mask[
                start_index:end_index, -batch.max_sequence_length :
            ] = batch.attention_mask[:, -batch.max_sequence_length :]

            for j, past in enumerate(batch.past_key_values):
                past_keys, past_values = past

                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])

                _, num_heads, _, head_dim = past_values.shape

                # The past covers every token except the one in input_ids,
                # hence max_sequence_length - 1
                padded_past_values_shape = (
                    total_batch_size,
                    num_heads,
                    max_sequence_length - 1,
                    head_dim,
                )

                if batch.keys_head_dim_last:
                    padded_past_keys_shape = padded_past_values_shape
                # seq_length is last for BLOOM
                else:
                    padded_past_keys_shape = (
                        total_batch_size,
                        num_heads,
                        head_dim,
                        max_sequence_length - 1,
                    )

                # This will run only once per layer
                if j == len(past_key_values):
                    padded_past_keys = torch.zeros(
                        padded_past_keys_shape,
                        dtype=past_keys.dtype,
                        device=past_keys.device,
                    )
                    padded_past_values = torch.zeros(
                        padded_past_values_shape,
                        dtype=past_values.dtype,
                        device=past_values.device,
                    )
                    past_key_values.append((padded_past_keys, padded_past_values))

                # We slice the past keys and values to remove the padding from previous batches
                if batch.keys_head_dim_last:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        -(batch.max_sequence_length - 1) :,
                        :,
                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                else:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        :,
                        -(batch.max_sequence_length - 1) :,
                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]

                past_key_values[j][1][
                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )
233
234
235


class CausalLM(Model):
    """Causal language model served through `AutoModelForCausalLM`.

    Runs one decoding step at a time on a `CausalLMBatch` and returns the
    generations that finished together with the batch to use for the next step.
    """

    def __init__(self, model_name: str, quantize=False):
        """Load the tokenizer and model for `model_name`.

        On CUDA the model is loaded with `device_map="auto"` in bfloat16 when
        supported (float32 otherwise); on CPU it is loaded in float32.

        Raises:
            ValueError: if `quantize` is requested and CUDA is not available.
        """
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if cuda_available else None,
            load_in_8bit=quantize,
        ).eval()

        # Use the model's pad token, falling back to its EOS token when the
        # config does not define one
        config = self.model.config
        pad_token_id = config.pad_token_id
        if pad_token_id is None:
            pad_token_id = config.eos_token_id
        tokenizer.pad_token_id = pad_token_id

        super().__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used with this model."""
        return CausalLMBatch

    def forward(
        self, input_ids, attention_mask, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the model with caching enabled and return `(logits, past_key_values)`."""
        model_output = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return model_output.logits, model_output.past_key_values

    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
        """Run one decoding step on `batch`.

        Each request gets one new token appended to its tokens. Requests whose
        stopping criteria fire are decoded and returned as `GeneratedText`;
        the others are carried over into the returned batch.

        Returns:
            `(generated_texts, next_batch)`, where `next_batch` is `None` when
            every request of the batch finished during this step.
        """
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        if self.device.type == "cpu":
            grad_context = torch.no_grad
        else:
            grad_context = torch.inference_mode

        with grad_context():
            logits, past = self.forward(
                batch.input_ids, batch.attention_mask, batch.past_key_values
            )

        # Positions (within the current batch) of the requests that continue
        kept_indices = []

        # Per-request values for the next forward
        kept_input_lengths = []
        kept_input_ids = []
        kept_all_input_ids = []

        # Metadata
        kept_max_sequence_length = 0

        # Finished requests
        generated_texts: List[GeneratedText] = []

        per_request = zip(
            batch.requests,
            batch.input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for position, (
            request,
            input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_tokens,
        ) in enumerate(per_request):
            # Select next token from the logits of the last position
            next_token = next_token_chooser(
                all_tokens, request_logits.unsqueeze(0)[:, -1]
            )

            # Append next token to all tokens
            all_tokens = torch.cat([all_tokens, next_token])

            # Evaluate stopping criteria
            if stopping_criteria(all_tokens):
                # Decode all tokens
                output = self.tokenizer.decode(
                    all_tokens.squeeze(-1), skip_special_tokens=True
                )
                # Add to the list of finished generations with the original request
                generated_texts.append(
                    GeneratedText(request, output, stopping_criteria.current_tokens)
                )
                continue

            # Not finished: carry this request over to the next batch
            new_input_length = input_length + 1
            kept_indices.append(position)
            kept_input_ids.append(next_token)
            kept_all_input_ids.append(all_tokens)
            kept_input_lengths.append(new_input_length)
            kept_max_sequence_length = max(kept_max_sequence_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if not kept_indices:
            return generated_texts, None

        next_size = len(kept_indices)
        next_input_ids = torch.cat(kept_input_ids, dim=0)

        if generated_texts:
            # At least one generation finished: evict it from everything that
            # is cached for the next step

            def keep(items):
                return [items[k] for k in kept_indices]

            next_attention_mask = batch.attention_mask[kept_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_past_key_values = []
            for layer in past:
                next_past_key_values.append(
                    [
                        t.view(batch.size, -1, *t.shape[-2:])[kept_indices]
                        for t in layer
                    ]
                )
            next_requests = keep(batch.requests)
            next_token_choosers = keep(batch.next_token_choosers)
            next_stopping_criterias = keep(batch.stopping_criterias)
        else:
            # Nothing finished: reuse the current values as they are
            next_attention_mask = batch.attention_mask
            next_past_key_values = past
            next_requests = batch.requests
            next_token_choosers = batch.next_token_choosers
            next_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        new_token_mask = next_attention_mask.new_ones(next_size, 1)
        next_attention_mask = torch.cat([next_attention_mask, new_token_mask], dim=1)

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_requests,
            input_ids=next_input_ids,
            attention_mask=next_attention_mask,
            past_key_values=next_past_key_values,
            all_input_ids=kept_all_input_ids,
            input_lengths=kept_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=next_stopping_criterias,
            size=next_size,
            max_sequence_length=kept_max_sequence_length,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generated_texts, next_batch