import torch
import torch.distributed

from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.models.bloom.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from bloom_inference.pb import generate_pb2
from bloom_inference.utils import (
    StoppingCriteria,
    NextTokenChooser,
    initialize_torch_distributed,
    weight_files,
    download_weights,
)

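# Seed once at import time so that sampling is reproducible and, in the
# sharded case, stays identical across tensor-parallel ranks.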
torch.manual_seed(0)


@dataclass
class Batch:
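    """State for one batch of in-flight generation requests.

    `input_ids` holds the tensors passed to the next model call (`input_ids`,
    `attention_mask` and, after the first step, `past_key_values`), while
    `all_input_ids` keeps the full token history of each request.
    """
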
    batch_id: int
    requests: List[generate_pb2.Request]
    all_input_lengths: List[int]
    input_ids: Dict[str, torch.Tensor]
    all_input_ids: List[torch.Tensor]
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    size: int
    max_sequence_length: int

    def to_pb(self):
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
            max_sequence_length=self.max_sequence_length,
        )

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
    ) -> "Batch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        all_input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            all_input_lengths.append(r.input_length)
            next_token_choosers.append(
                NextTokenChooser(
                    temperature=r.parameters.temperature,
                    top_k=r.parameters.top_k,
                    top_p=r.parameters.top_p,
                    do_sample=r.parameters.do_sample,
                )
            )
            stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))

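        # Pad on the left (the tokenizer is created with padding_side="left")
        # and round the padded length up to a multiple of 8 so the resulting
        # shapes are tensor-core friendly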
        input_ids = tokenizer(
            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
        ).to(device)
        all_input_ids = input_ids["input_ids"].unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_sequence_length=pb.max_sequence_length,
        )

    @classmethod
    def concatenate(cls, batches: List["Batch"]) -> "Batch":
        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_sequence_length = max(batch.max_sequence_length for batch in batches)

        # Batch attributes
        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
        requests = []
        all_input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            all_input_lengths.extend(batch.all_input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.input_ids["input_ids"].shape[1] > 1:
                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")

            # Initialize tensors
            if i == 0:
                input_ids["input_ids"] = torch.empty(
                    (total_batch_size, 1),
                    dtype=batch.input_ids["input_ids"].dtype,
                    device=batch.input_ids["input_ids"].device,
                )
                input_ids["attention_mask"] = torch.zeros(
                    (total_batch_size, max_sequence_length),
                    dtype=batch.input_ids["attention_mask"].dtype,
                    device=batch.input_ids["attention_mask"].device,
                )

            # input_ids["input_ids"] is always of shape [batch_size, 1]
            # We do not need to pad it
            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]

            # We need to slice the attention mask to remove padding from previous steps
            input_ids["attention_mask"][
                start_index:end_index, -batch.max_sequence_length :
            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]

            for j, past in enumerate(batch.input_ids["past_key_values"]):
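                # BLOOM flattens heads into the batch dimension of the cache:
                # keys are [batch_size * num_heads, head_dim, seq_length] and
                # values are [batch_size * num_heads, seq_length, head_dim]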
                past_keys = past[0]
                past_values = past[1]

                _, head_dim, padded_sequence_length = past_keys.shape

                # Reshape the tensors to make slicing easier
                past_keys = past_keys.view(
                    batch.size, -1, head_dim, padded_sequence_length
                )
                past_values = past_values.view(
                    batch.size, -1, padded_sequence_length, head_dim
                )
                num_heads = past_keys.shape[1]

                # Initialize tensors
                # This will run only once per layer
                if j == len(input_ids["past_key_values"]):
                    padded_past_keys = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            head_dim,
                            max_sequence_length - 1,
                        ),
                        dtype=past_keys.dtype,
                        device=past_keys.device,
                    )
                    padded_past_values = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            max_sequence_length - 1,
                            head_dim,
                        ),
                        dtype=past_values.dtype,
                        device=past_values.device,
                    )
                    input_ids["past_key_values"].append(
                        [padded_past_keys, padded_past_values]
                    )

                # We slice the past keys and values to remove the padding from previous batches
                input_ids["past_key_values"][j][0][
                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]

                input_ids["past_key_values"][j][1][
                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

                # If we are on the last batch, we need to reshape the tensors
                if (i + 1) == len(batches):
                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
                        j
                    ][0].view(total_batch_size * num_heads, head_dim, -1)
                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
                        j
                    ][1].view(total_batch_size * num_heads, -1, head_dim)

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
        )


@dataclass
class GeneratedText:
    request: generate_pb2.Request
    output: str

    def to_pb(self) -> generate_pb2.GeneratedText:
        return generate_pb2.GeneratedText(request=self.request, output=self.output)


class BLOOM:
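    """BLOOM causal-LM wrapper that decodes one token per generate_token call."""
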
    def __init__(self, model_name: str):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            dtype = torch.bfloat16
        else:
            self.device = torch.device("cpu")
            dtype = torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_name)
            .eval()
            .to(self.device)
            .to(dtype)
        )
        self.num_heads = self.model.base_model.num_heads

    def forward(
        self, input_ids, attention_mask, past_key_values: Optional[List] = None
    ):
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

    def generate_token(
        self, batch: Batch
    ) -> Tuple[List[GeneratedText], Optional[Batch]]:
        with torch.no_grad():
            outputs = self.forward(**batch.input_ids)

        # List of indices to cache
        next_batch_keep_indices = []
        next_batch_past_keep_indices = []

        # New input_ids for next forward
        next_batch_input_ids = []
        next_batch_all_input_ids = []
        next_all_input_lengths = []

        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Finished requests
        generated_texts: List[GeneratedText] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.all_input_lengths,
            outputs.logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_tokens,
        ) in enumerate(iterator):
            # Select next token
            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])

            # Append next token to all tokens
            all_tokens = torch.cat([all_tokens, next_token])

            # Evaluate stopping criteria
            if stopping_criteria(all_tokens):
                # Decode all tokens
                output = self.tokenizer.decode(
                    all_tokens.squeeze(-1), skip_special_tokens=True
                )
                # Add to the list of finished generations with the original request
                generated_texts.append(GeneratedText(request, output))
            # Add to the next batch
            else:
                next_batch_keep_indices.append(i)
                # past_key_values is of shape [batch_size * num_heads, ...]
                # so we need to take into account the `num_heads` stride here
                next_batch_past_keep_indices.extend(
                    range(i * self.num_heads, (i + 1) * self.num_heads)
                )
                next_batch_input_ids.append(next_token)
                next_batch_all_input_ids.append(all_tokens)
                next_batch_size += 1
                new_input_length = input_length + 1
                next_all_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generated_texts, None

        # If we finished at least one generation
        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
        if generated_texts:
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
                next_batch_keep_indices
            ]
            next_batch_input_ids["past_key_values"] = [
                (
                    keys[next_batch_past_keep_indices],
                    values[next_batch_past_keep_indices],
                )
                for keys, values in outputs["past_key_values"]
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
            next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        next_batch_input_ids["attention_mask"] = torch.cat(
            [
                next_batch_input_ids["attention_mask"],
                torch.ones((next_batch_size, 1), device=self.device),
            ],
            dim=1,
        )

        next_batch = Batch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            all_input_lengths=next_all_input_lengths,
            input_ids=next_batch_input_ids,
            all_input_ids=next_batch_all_input_ids,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
        )
        return generated_texts, next_batch
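
# Minimal usage sketch (the `pb_batch` below is hypothetical, standing in for
# a generate_pb2.Batch built by the gRPC layer; shown only to illustrate the
# incremental decoding loop):
#
#     model = BLOOM("bigscience/bloom-560m")
#     batch = Batch.from_pb(pb_batch, model.tokenizer, model.device)
#     generated_texts, batch = model.generate_token(batch)
#     while batch is not None:
#         texts, batch = model.generate_token(batch)
#         generated_texts.extend(texts)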


class BLOOMSharded(BLOOM):
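    """Tensor-parallel BLOOM: every rank loads a slice of each parameter and
    the shards cooperate through torch.distributed collectives."""
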
    def __init__(self, model_name: str):
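        # Deliberately skip BLOOM.__init__: it would load the full, unsharded
        # checkpoint, whereas the sharded weights are loaded explicitly below.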
        super(BLOOM, self).__init__()
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16
        else:
            self.device = torch.device("cpu")
            dtype = torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

        config = AutoConfig.from_pretrained(
            model_name, slow_but_exact=False, tp_parallel=True
        )
        config.pad_token_id = 3
        self.num_heads = config.n_head // self.process_group.size()

        # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
        # in PyTorch 1.12 and later.
        torch.backends.cuda.matmul.allow_tf32 = True

        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
        torch.backends.cudnn.allow_tf32 = True

        # Only download weights for small models
        if self.master and model_name == "bigscience/bloom-560m":
            download_weights(model_name)

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_name)

        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            device=self.device,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval().to(dtype)
        torch.distributed.barrier(group=self.process_group)

    @staticmethod
    def load_weights(
        model, filenames: List[str], device: torch.device, rank: int, world_size: int
    ):
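        """Stream safetensors files into the meta-initialized model, slicing
        each tensor to this rank's shard according to the owning module's
        tensor-parallel layout."""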
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(file, framework="pt", device=str(device)) as f:
                for name in f.keys():
                    full_name = f"transformer.{name}"

                    module_name, param_name = full_name.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_tensor = parameters[full_name]

                    slice_ = f.get_slice(name)

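                    # Column-parallel weights are sharded along dim 0 of the
                    # stored tensor, row-parallel weights along dim 1, and
                    # embeddings along the vocabulary (dim 0); everything else
                    # is loaded whole on every rank.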
                    if isinstance(module, TensorParallelColumnLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[0]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[start:stop]
                            tensor = tensor.transpose(1, 0)
                        else:
                            size = slice_.get_shape()[0]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                            tensor = tensor.transpose(1, 0)
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        tensor = slice_[:]

                    if current_tensor.shape != tensor.shape:
                        raise ValueError(
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous()
                    module._parameters[param_name] = tensor
                    if name == "word_embeddings.weight":
                        model.lm_head._parameters["weight"] = tensor

    def forward(
        self, input_ids, attention_mask, past_key_values: Optional[List] = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
        logits_shard = outputs.logits[:, -1, :].contiguous()

        batch_size, vocab_shard_size = logits_shard.shape
        vocab_size = self.world_size * vocab_shard_size
        logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
        logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)

        outputs.logits = logits
        return outputs