import torch
import torch.distributed

from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional, Dict

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_utils import no_init_weights

from bloom_inference.pb import generate_pb2
from bloom_inference.shard_model import shard_model, match_suffix
from bloom_inference.utils import (
    StoppingCriteria,
    NextTokenChooser,
    initialize_torch_distributed,
    set_default_dtype,
)

torch.manual_seed(0)


@dataclass
class Batch:
    batch_id: int
    requests: List[generate_pb2.Request]
    input_ids: Dict[str, torch.Tensor]
    all_input_ids: List[torch.Tensor]
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    size: int
    max_sequence_length: int

    def to_pb(self):
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
            max_sequence_length=self.max_sequence_length,
        )

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
    ) -> "Batch":
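        # Build a Batch from its protobuf representation: set up per-request
        # sampling and stopping logic, then tokenize all inputs with left padding.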
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            next_token_choosers.append(
                NextTokenChooser(
                    temperature=r.parameters.temperature,
                    top_k=r.parameters.top_k,
                    top_p=r.parameters.top_p,
                    do_sample=r.parameters.do_sample,
                )
            )
            stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))

        input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device)
        # Remove padding from all_input_ids
        all_input_ids = [
            input_ids.squeeze(0)[-length:].unsqueeze(-1)
            for length, input_ids in zip(
                input_lengths, input_ids["input_ids"].split(1, dim=0)
            )
        ]

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_sequence_length=pb.max_sequence_length,
        )

    @classmethod
    def concatenate(cls, batches: List["Batch"]) -> "Batch":
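        # Merge batches that already ran at least one decode step into a single
        # batch, re-padding attention masks and past key/values to the longest
        # sequence length across all batches.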
        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_sequence_length = max(batch.max_sequence_length for batch in batches)

        # Batch attributes
        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
        requests = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.input_ids["input_ids"].shape[1] > 1:
                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")

            # Initialize tensors
            if i == 0:
                input_ids["input_ids"] = torch.empty(
                    (total_batch_size, 1),
                    dtype=batch.input_ids["input_ids"].dtype,
                    device=batch.input_ids["input_ids"].device,
                )
                input_ids["attention_mask"] = torch.zeros(
                    (total_batch_size, max_sequence_length),
                    dtype=batch.input_ids["attention_mask"].dtype,
                    device=batch.input_ids["attention_mask"].device,
                )

            # input_ids["input_ids"] is always of shape [batch_size, 1]
            # We do not need to pad it
            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]

            # We need to slice the attention mask to remove padding from previous steps
            input_ids["attention_mask"][
                start_index:end_index, -batch.max_sequence_length :
            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]

            for j, past in enumerate(batch.input_ids["past_key_values"]):
                past_keys = past[0]
                past_values = past[1]

                _, head_dim, padded_sequence_length = past_keys.shape
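                # Past keys are laid out as [batch * num_heads, head_dim, seq_length]
                # and past values as [batch * num_heads, seq_length, head_dim].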

                # Reshape the tensors to make slicing easier
                past_keys = past_keys.view(
                    batch.size, -1, head_dim, padded_sequence_length
                )
                past_values = past_values.view(
                    batch.size, -1, padded_sequence_length, head_dim
                )
                num_heads = past_keys.shape[1]

                # Initialize tensors
                # This will run only once per layer
                if j == len(input_ids["past_key_values"]):
                    padded_past_keys = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            head_dim,
                            max_sequence_length - 1,
                        ),
                        dtype=past_keys.dtype,
                        device=past_keys.device,
                    )
                    padded_past_values = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            max_sequence_length - 1,
                            head_dim,
                        ),
                        dtype=past_values.dtype,
                        device=past_values.device,
                    )
                    input_ids["past_key_values"].append(
                        [padded_past_keys, padded_past_values]
                    )

                # We slice the past keys and values to remove the padding from previous batches
                input_ids["past_key_values"][j][0][
                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]

                input_ids["past_key_values"][j][1][
                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

                # If we are on the last batch, we need to reshape the tensors
                if (i + 1) == len(batches):
                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
                        j
                    ][0].view(total_batch_size * num_heads, head_dim, -1)
                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
                        j
                    ][1].view(total_batch_size * num_heads, -1, head_dim)

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
        )


@dataclass
class GeneratedText:
    request: generate_pb2.Request
    output: str

    def to_pb(self) -> generate_pb2.GeneratedText:
        return generate_pb2.GeneratedText(request=self.request, output=self.output)


class BLOOM:
    def __init__(self, model_name: str):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device)
        )
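        # num_heads is needed when filtering past_key_values, which the model keeps
        # fused along the batch * num_heads dimension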
        self.num_heads = self.model.base_model.num_heads

    def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

    def generate_token(
        self, batch: Batch
    ) -> Tuple[List[GeneratedText], Optional[Batch]]:
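        # One decode step for the whole batch: sample a new token for every request,
        # collect the requests whose stopping criteria are met, and build the batch
        # used for the next step (None if every request finished).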
        with torch.no_grad():
            outputs = self.forward(**batch.input_ids)

        # List of indices to cache
        next_batch_keep_indices = []
        next_batch_past_keep_indices = []

        # New input_ids for next forward
        next_batch_input_ids = []
        next_batch_all_input_ids = []

        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Finished requests
        generated_texts: List[GeneratedText] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            outputs.logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_tokens,
        ) in enumerate(iterator):
            # Select next token
            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])

            # Append next token to all tokens
            all_tokens = torch.cat([all_tokens, next_token])

            # Evaluate stopping criteria
            if stopping_criteria(all_tokens):
                # Decode all tokens
                output = self.tokenizer.decode(
                    all_tokens.squeeze(-1), skip_special_tokens=True
                )
                # Add to the list of finished generations with the original request
                generated_texts.append(GeneratedText(request, output))
            # add to the next batch
            else:
                next_batch_keep_indices.append(i)
                # past_key_values is of shape [batch_size * num_heads, ...]
                # so we need to take into account the `num_heads` stride here
                next_batch_past_keep_indices.extend(
                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
                )
                next_batch_input_ids.append(next_token)
                next_batch_all_input_ids.append(all_tokens)
                next_batch_size += 1
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, len(all_tokens)
                )

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generated_texts, None

        # If we finished at least one generation, we need to evict the finished
        # requests' state from the tensors cached for the next batch
        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
        if generated_texts:
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
                next_batch_keep_indices
            ]
            next_batch_input_ids["past_key_values"] = [
                (
                    keys[next_batch_past_keep_indices],
                    values[next_batch_past_keep_indices],
                )
                for keys, values in outputs["past_key_values"]
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
            next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        next_batch_input_ids["attention_mask"] = torch.cat(
            [
                next_batch_input_ids["attention_mask"],
                torch.ones((next_batch_size, 1)).to(self.device),
            ],
            dim=1,
        )

        next_batch = Batch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            all_input_ids=next_batch_all_input_ids,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
        )
        return generated_texts, next_batch


class BLOOMSharded(BLOOM):
    def __init__(self, model_name: str, shard_directory: Path):
        super(BLOOM, self).__init__()
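        # super(BLOOM, self) resolves past BLOOM in the MRO, so BLOOM.__init__
        # (which loads the full, unsharded model onto a single device) is skipped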
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16
        else:
            self.device = torch.device("cpu")
            dtype = torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

        # shard state_dict
        if self.master:
            # TODO @thomasw21 do some caching
            shard_state_dict_paths = shard_model(
                model_name, shard_directory, tp_world_size=self.world_size, dtype=dtype
            )
            shard_state_dict_paths = [
                str(path.absolute()) for path in shard_state_dict_paths
            ]
        else:
            shard_state_dict_paths = [None] * self.world_size

        torch.distributed.broadcast_object_list(
            shard_state_dict_paths, src=0, group=self.process_group
        )
        shard_state_dict_path = shard_state_dict_paths[self.rank]

        config = AutoConfig.from_pretrained(
            model_name, slow_but_exact=False, tp_parallel=True
        )
        config.pad_token_id = 3

        # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
        # in PyTorch 1.12 and later.
        torch.backends.cuda.matmul.allow_tf32 = True

        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
        torch.backends.cudnn.allow_tf32 = True

        with set_default_dtype(dtype):
            with no_init_weights():
                # we can probably set the device to `meta` here?
                model = AutoModelForCausalLM.from_config(config).to(dtype)

        torch.distributed.barrier(group=self.process_group)
        # print_rank_0(f"Initialized model")
        state_dict = torch.load(shard_state_dict_path)
        # TODO @thomasw21: HACK in order to transpose all weight prior
        for key in state_dict.keys():
            do_transpose = False
            if not match_suffix(key, "weight"):
                continue

            for potential_suffix in [
                "self_attention.query_key_value.weight",
                "self_attention.dense.weight",
                "dense_h_to_4h.weight",
                "dense_4h_to_h.weight",
            ]:
                if match_suffix(key, potential_suffix):
                    do_transpose = True

            if do_transpose:
                state_dict[key] = state_dict[key].transpose(1, 0).contiguous()

        model.load_state_dict(state_dict)
        self.model = model.to(self.device).eval()
        self.num_heads = config.n_head // self.process_group.size()
        torch.distributed.barrier(group=self.process_group)

    def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        logits_shard = outputs.logits[:, -1, :].contiguous()

        batch_size, vocab_shard_size = logits_shard.shape
        vocab_size = self.world_size * vocab_shard_size
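        # Each rank only computes the logits for its vocabulary shard; all-gather
        # the shards and concatenate them so sampling sees the full vocabulary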
        logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
        logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)

        outputs.logits = logits
        return outputs
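

# Example (sketch): assuming an incoming protobuf batch `pb_batch: generate_pb2.Batch`
# is already available and using an illustrative model name, a minimal decode loop
# that drives a batch until every request has finished could look like:
#
#     model = BLOOM("bigscience/bloom-560m")
#     batch = Batch.from_pb(pb_batch, model.tokenizer, model.device)
#     finished: List[GeneratedText] = []
#     while batch is not None:
#         generated_texts, batch = model.generate_token(batch)
#         finished.extend(generated_texts)
#     for generated_text in finished:
#         print(generated_text.output)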