import torch
import torch.distributed

from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.models.bloom.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from bloom_inference.pb import generate_pb2
from bloom_inference.utils import (
    StoppingCriteria,
    NextTokenChooser,
    initialize_torch_distributed,
    weight_files,
    download_weights,
)

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception:
    HAS_BITS_AND_BYTES = False

torch.manual_seed(0)


@dataclass
class Batch:
    batch_id: int
    requests: List[generate_pb2.Request]
    all_input_lengths: List[int]
    input_ids: Dict[str, torch.Tensor]
    all_input_ids: List[torch.Tensor]
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    size: int
    max_sequence_length: int

    def to_pb(self):
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
            max_sequence_length=self.max_sequence_length,
        )

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
    ) -> "Batch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        all_input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            all_input_lengths.append(r.input_length)
            next_token_choosers.append(
                NextTokenChooser(
                    temperature=r.parameters.temperature,
                    top_k=r.parameters.top_k,
                    top_p=r.parameters.top_p,
                    do_sample=r.parameters.do_sample,
                )
            )
            stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))

        input_ids = tokenizer(
            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
        ).to(device)
        all_input_ids = input_ids["input_ids"].unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_sequence_length=pb.max_sequence_length,
        )

    @classmethod
    def concatenate(cls, batches: List["Batch"]) -> "Batch":
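        # Merge several in-flight batches into one so they can share a single forward
        # pass (continuous batching): inputs are left-aligned to the longest sequence
        # and the cached key/values are copied into freshly allocated, padded tensors.
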
        # Total size and longest sequence across all batches; used to size the padded tensors
        total_batch_size = sum(batch.size for batch in batches)
        max_sequence_length = max(batch.max_sequence_length for batch in batches)

        # Batch attributes
        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
        requests = []
        all_input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            all_input_lengths.extend(batch.all_input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.input_ids["input_ids"].shape[1] > 1:
                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")

            # Initialize tensors
            if i == 0:
                input_ids["input_ids"] = torch.empty(
                    (total_batch_size, 1),
                    dtype=batch.input_ids["input_ids"].dtype,
                    device=batch.input_ids["input_ids"].device,
                )
                input_ids["attention_mask"] = torch.zeros(
                    (total_batch_size, max_sequence_length),
                    dtype=batch.input_ids["attention_mask"].dtype,
                    device=batch.input_ids["attention_mask"].device,
                )

            # input_ids["input_ids"] is always of shape [batch_size, 1]
            # We do not need to pad it
            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]

            # We need to slice the attention mask to remove padding from previous steps
            input_ids["attention_mask"][
                start_index:end_index, -batch.max_sequence_length :
            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]

            for j, past in enumerate(batch.input_ids["past_key_values"]):
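                # BLOOM fuses the batch and head dimensions in its KV cache:
                # keys are [batch_size * num_heads, head_dim, seq_length] and
                # values are [batch_size * num_heads, seq_length, head_dim].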
                past_keys = past[0]
                past_values = past[1]

                _, head_dim, padded_sequence_length = past_keys.shape

                # Reshape the tensors to make slicing easier
                past_keys = past_keys.view(
                    batch.size, -1, head_dim, padded_sequence_length
                )
                past_values = past_values.view(
                    batch.size, -1, padded_sequence_length, head_dim
                )
                num_heads = past_keys.shape[1]

                # Initialize tensors
                # This will run only once per layer
                if j == len(input_ids["past_key_values"]):
                    padded_past_keys = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            head_dim,
                            max_sequence_length - 1,
                        ),
                        dtype=past_keys.dtype,
                        device=past_keys.device,
                    )
                    padded_past_values = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            max_sequence_length - 1,
                            head_dim,
                        ),
                        dtype=past_values.dtype,
                        device=past_values.device,
                    )
                    input_ids["past_key_values"].append(
                        [padded_past_keys, padded_past_values]
                    )

                # We slice the past keys and values to remove the padding from previous batches
                input_ids["past_key_values"][j][0][
                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]

                input_ids["past_key_values"][j][1][
                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

                # If we are on the last batch, we need to reshape the tensors
                if (i + 1) == len(batches):
                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
                        j
                    ][0].view(total_batch_size * num_heads, head_dim, -1)
                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
                        j
                    ][1].view(total_batch_size * num_heads, -1, head_dim)

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
        )


@dataclass
class GeneratedText:
    request: generate_pb2.Request
    output: str

    def to_pb(self) -> generate_pb2.GeneratedText:
        return generate_pb2.GeneratedText(request=self.request, output=self.output)


class BLOOM:
    def __init__(self, model_name: str):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            dtype = torch.bfloat16
        else:
            self.device = torch.device("cpu")
            dtype = torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_name)
            .eval()
            .to(self.device)
            .to(dtype)
        )
        self.num_heads = self.model.base_model.num_heads

    def forward(self, input_ids, attention_mask, past_key_values: Optional[List] = None):
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

    def generate_token(
        self, batch: Batch
    ) -> Tuple[List[GeneratedText], Optional[Batch]]:
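        # Run a single decoding step for the whole batch: one forward pass, one new
        # token per sequence, then split requests into finished generations and a
        # smaller Batch to be forwarded again on the next call.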
        with torch.inference_mode():
            outputs = self.forward(**batch.input_ids)

        # List of indices to cache
        next_batch_keep_indices = []
        next_batch_past_keep_indices = []

        # New input_ids for next forward
        next_batch_input_ids = []
        next_batch_all_input_ids = []
        next_all_input_lengths = []

        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Finished requests
        generated_texts: List[GeneratedText] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.all_input_lengths,
            outputs.logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_tokens,
        ) in enumerate(iterator):
            # Select next token
            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])

            # Append next token to all tokens
            all_tokens = torch.cat([all_tokens, next_token])

            # Evaluate stopping criteria
            if stopping_criteria(all_tokens):
                # Decode all tokens
                output = self.tokenizer.decode(
                    all_tokens.squeeze(-1), skip_special_tokens=True
                )
                # Add to the list of finished generations with the original request
                generated_texts.append(GeneratedText(request, output))
            # add to the next batch
            else:
                next_batch_keep_indices.append(i)
                # past_key_values is of shape [batch_size * num_heads, ...]
                # so we need to take into account the `num_heads` stride here
                next_batch_past_keep_indices.extend(
                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
                )
                next_batch_input_ids.append(next_token)
                next_batch_all_input_ids.append(all_tokens)
                next_batch_size += 1
                new_input_length = input_length + 1
                next_all_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generated_texts, None

        # If we finished at least one generation
        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
        if generated_texts:
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
                next_batch_keep_indices
            ]
            next_batch_input_ids["past_key_values"] = [
                (
                    keys[next_batch_past_keep_indices],
                    values[next_batch_past_keep_indices],
                )
                for keys, values in outputs["past_key_values"]
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
            next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        next_batch_input_ids["attention_mask"] = torch.cat(
            [
                next_batch_input_ids["attention_mask"],
                torch.ones((next_batch_size, 1)).to(self.device),
            ],
            dim=1,
        )

        next_batch = Batch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            all_input_lengths=next_all_input_lengths,
            input_ids=next_batch_input_ids,
            all_input_ids=next_batch_all_input_ids,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
        )
        return generated_texts, next_batch


class BLOOMSharded(BLOOM):
    def __init__(self, model_name: str, quantize: bool = False):
        super(BLOOM, self).__init__()
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{self.rank}")
            dtype = torch.float16
        else:
            self.device = torch.device("cpu")
            dtype = torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

        config = AutoConfig.from_pretrained(
            model_name, slow_but_exact=False, tp_parallel=True
        )
        config.pad_token_id = 3
        self.num_heads = config.n_head // self.process_group.size()

        # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
        # in PyTorch 1.12 and later.
        torch.backends.cuda.matmul.allow_tf32 = True

        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
        torch.backends.cudnn.allow_tf32 = True

        # Only download weights for small models
        if self.master and model_name == "bigscience/bloom-560m":
            download_weights(model_name)

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_name)

        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=self.device,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval().to(dtype)
        torch.distributed.barrier(group=self.process_group)

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        rank: int,
        world_size: int,
    ):
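        # Load weights directly from the safetensors files and shard them on the fly:
        # column-parallel linears are split along dim 0, row-parallel linears along
        # dim 1 and the embedding along the vocabulary dimension, so each rank only
        # materializes its own slice of the sharded parameters (everything else is
        # replicated).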
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    full_name = f"transformer.{name}"

                    module_name, param_name = full_name.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_tensor = parameters[full_name]

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[0]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[start:stop]
                            tensor = tensor.transpose(1, 0)
                        else:
                            size = slice_.get_shape()[0]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                            tensor = tensor.transpose(1, 0)
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        tensor = slice_[:]

                    if current_tensor.shape != tensor.shape:
                        raise ValueError(
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous()

                    if quantize:
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine"
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
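                            # Wrap the shard in a bitsandbytes Int8Params and keep its
                            # quantization state (CB/SCB) so the module's linear can be
                            # replaced by an LLM.int8() matmul below.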
                            tensor = Int8Params(
                                tensor.transpose(1, 0),
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

                            def replace_linear(state, in_features, out_features):
                                def linear(input, weight, bias):
                                    size_out = input.size()[:-1] + (out_features,)
                                    input = input.view(-1, in_features)
                                    out = torch.empty(
                                        size_out, device=input.device, dtype=input.dtype
                                    )
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        out=out.view(-1, out_features),
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out.view(size_out)

                                return linear

                            module.linear = replace_linear(
                                state, module.in_features, module.out_features
                            )

                        else:
                            tensor = tensor.to(device)

                    module._parameters[param_name] = tensor
                    if name == "word_embeddings.weight":
                        model.lm_head._parameters["weight"] = tensor

    def forward(self, input_ids, attention_mask, past_key_values: Optional[List] = None):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
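        # Only the last position is needed for sampling: each rank contributes a
        # [batch_size, vocab_shard_size] slice and the shards are concatenated along
        # the vocabulary dimension.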
        logits_shard = outputs.logits[:, -1, :].contiguous()

        batch_size, vocab_shard_size = logits_shard.shape
        vocab_size = self.world_size * vocab_shard_size
        logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
        logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)

        outputs.logits = logits
        return outputs