weights.py
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.utils.log import log_once


class Weights:
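    """Lazily load model weights from a set of safetensors checkpoint files.

    A routing table maps every tensor name to the file that contains it; tensors
    are only materialized, converted to `dtype`, and moved to `device` when one
    of the accessors is called. The sharded accessors split tensors across the
    ranks of `process_group` for tensor parallelism.
    """
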
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self._handles = {}

    def _get_handle(self, filename):
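        """Return a cached safetensors handle for `filename`, opening the file on first use."""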
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
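        """Resolve `tensor_name` (trying the configured prefix and aliases) to (filename, stored name)."""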
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
            names.append(prefixed)
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            aliases = self.aliases.get(name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
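        """Return a lazy slice view of the tensor, without loading it fully into memory."""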
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
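        """Load a full tensor, casting to `self.dtype` unless it holds packed
        quantized integer data, and moving it to `self.device` if requested."""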
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq: packed u4 weights are disguised as int32 and
        # should not be converted to the model dtype. Exl2 uses int16 as well.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
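        """Return this rank's shard of the tensor along `dim`; the last rank may
        receive a smaller block when the size does not divide evenly."""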
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq: packed u4 weights are disguised as int32 and
        # should not be converted to the model dtype. Exl2 uses int16 as well.
        if tensor.dtype not in (torch.int16, torch.int32):
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
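        """Return this rank's shard of the tensor along `dim`, requiring the size
        to divide evenly across the world size."""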
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
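        """Shard a packed quantized tensor (e.g. fused QKV) column-wise, taking
        this rank's slice of every block."""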
        slice_ = self._get_slice(name)
        total_size = slice_.get_shape()[1]
        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        weights = []
        block_offset = 0
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
            shard_block_size = block_size // world_size
            start = rank * shard_block_size
            stop = (rank + 1) * shard_block_size
            weights.append(slice_[:, block_offset + start : block_offset + stop])
            block_offset += block_size

        weight = torch.cat(weights, dim=1)
        weight = weight.to(device=self.device)
        return weight

    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        quantize: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
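        """Load a packed QKV weight, splitting its columns into Q, K and V blocks
        sized from the head counts before sharding."""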
        return self.get_weights_col_packed(
            prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
        )

    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
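        """Load a packed gate/up projection weight as two equally sized column blocks."""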
        return self.get_weights_col_packed(prefix, quantize, 2)

    def get_weights_col_packed(
        self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
    ):
        """
        Load a packed weight whose columns are a simple concatenation of Q, K, V,
        rather than Q, K, V already alternating within the main tensor.

        The columns are split into equally sized blocks when `block_sizes` is an
        `int`, or into blocks proportional to the given sizes. For instance,
        `[2, 1, 1]` will divide an input of dimensionality `1024` into
        `[512, 256, 256]`. This is convenient for e.g. splitting QKV without
        knowing the storage details of quantized weights.
        """
        if quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight

            try:
                qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                )

            bits, groupsize, _, quant_method = self._get_gptq_params()

            qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
            scales = self._get_qweight(f"{prefix}.scales", block_sizes)
            scales = scales.to(dtype=self.dtype)

            if quantize == "gptq" and quant_method == "gptq":
                g_idx = self.get_tensor(f"{prefix}.g_idx")
            elif quantize == "gptq" and quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                g_idx = (
                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
                    // groupsize
                ).to(dtype=torch.int32)
            else:
                g_idx = None

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=bits,
                groupsize=groupsize,
                use_exllama=False,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.marlin import MarlinWeight

            B = self._get_qweight(f"{prefix}.B", block_sizes)
            s = self._get_qweight(f"{prefix}.s", block_sizes)
            weight = MarlinWeight(B=B, s=s)
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
            total_size = slice_.get_shape()[0]
            block_sizes = _blocks_to_block_sizes(
                total_size=total_size, blocks=block_sizes
            )

            world_size = self.process_group.size()
            rank = self.process_group.rank()

            tensors = []
            block_offset = 0
            for block_size in block_sizes:
                assert (
                    block_size % world_size == 0
                ), f"Prepacked weights cannot be sharded across {world_size} shards"
                shard_block_size = block_size // world_size
                start = rank * shard_block_size
                stop = (rank + 1) * shard_block_size
                tensor = slice_[block_offset + start : block_offset + stop]
                tensors.append(tensor)
                block_offset += block_size
            weight = torch.cat(tensors, dim=0)
            weight = weight.to(device=self.device)
            weight = weight.to(dtype=self.dtype)
        return weight

    def get_weights_col(self, prefix: str, quantize: str):
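        """Load a column-parallel weight for `prefix`; exl2 weights cannot be
        sharded and are returned whole, other formats delegate to
        `get_multi_weights_col`."""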
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        return self.get_multi_weights_col([prefix], quantize, 0)

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
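        """Load and concatenate the column shards of several prefixes (e.g.
        separate q/k/v projections), handling quantized formats."""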
        if quantize == "exl2":
            raise ValueError("get_multi_weights_col is not supported for exl2")
        elif quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight

            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                )

            qzeros = torch.cat(
                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
            scales = torch.cat(
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )

            bits, groupsize, desc_act, quant_method = self._get_gptq_params()

            from text_generation_server.layers.gptq import HAS_EXLLAMA

            use_exllama = (
                bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
            )

            if quantize == "gptq" and quant_method == "gptq":
                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]
            elif quantize == "gptq" and quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
                            qweight.shape[0] * (32 // bits), device=qweight.device
                        )
                        // groupsize
                    ).to(dtype=torch.int32)
            else:
                g_idx = None

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=bits,
                groupsize=groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.marlin import MarlinWeight

            try:
                B = torch.cat(
                    [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                )
            s = torch.cat([self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1)

            weight = MarlinWeight(B=B, s=s)

        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)

        return weight

    def get_tensor_shard(self, var, dim):
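        """Shard an already materialized tensor `var` along `dim` for the current rank."""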
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_multi_weights_row(self, prefix: str, quantize: str):
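        """Load a row-parallel weight for `prefix`, sharded along the input
        dimension, handling exl2, gptq, awq and marlin formats."""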
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        elif quantize == "gptq":
            use_exllama = True
            bits, groupsize, desc_act, quant_method = self._get_gptq_params()

            if bits != 4:
                use_exllama = False

            if desc_act:
                log_once(logger.warning, "Disabling exllama because desc_act=True")
                use_exllama = False

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            if quant_method == "gptq":
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
            elif quant_method == "awq":
                g_idx = None

            if self.process_group.size() > 1:
                if g_idx is not None:
                    if (
                        not torch.equal(
                            g_idx.cpu(),
                            torch.tensor(
                                [i // groupsize for i in range(g_idx.shape[0])],
                                dtype=torch.int32,
                            ),
                        )
                        and not (g_idx == 0).all()
                    ):
                        # The exllama implementation does not support row tensor parallelism
                        # with act-order, as it would require reordering the input activations
                        # that are split across several GPUs.
                        use_exllama = False

            from text_generation_server.layers.gptq import (
                HAS_EXLLAMA,
                CAN_EXLLAMA,
                GPTQWeight,
            )

            if use_exllama:
                if not HAS_EXLLAMA:
                    if CAN_EXLLAMA:
                        log_once(
                            logger.warning,
                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
                        )
                    use_exllama = False
                else:
                    log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

            if use_exllama and groupsize != -1:
                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                scales = self.get_sharded(f"{prefix}.scales", dim=0)
            else:
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")

            if use_exllama and g_idx is not None:
                g_idx = g_idx - g_idx[0]

            if quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
                            qweight.shape[0] * (32 // bits), device=qweight.device
                        )
                        // groupsize
                    ).to(dtype=torch.int32)

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=bits,
                groupsize=groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "awq":
            from text_generation_server.layers.gptq import GPTQWeight

            bits, groupsize, _, _ = self._get_gptq_params()

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `awq` weight, make sure the model is already quantized"
                )

            qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
            scales = self.get_sharded(f"{prefix}.scales", dim=0)
            g_idx = None
            use_exllama = False

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=bits,
                groupsize=groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.marlin import MarlinWeight

            try:
                B = self.get_sharded(f"{prefix}.B", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
            if num_groups == 1:
                # The number of groups is 1 when group_size == -1. share
                # scales between all shards in this case.
                s = self.get_tensor(f"{prefix}.s")
            else:
                s = self.get_sharded(f"{prefix}.s", dim=0)
            weight = MarlinWeight(B=B, s=s)

        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

    def _get_gptq_params(self) -> Tuple[int, int, bool, str]:
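        """Return (bits, groupsize, desc_act, quant_method), read from checkpoint
        tensors when present, otherwise from attributes set by `_set_gptq_params`."""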
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
            desc_act = False
            quant_method = "gptq"
        except (SafetensorError, RuntimeError) as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
                desc_act = getattr(self, "gptq_desc_act", False)
                quant_method = getattr(self, "quant_method", "gptq")
            except Exception:
                raise e

        return bits, groupsize, desc_act, quant_method

    def _set_gptq_params(self, model_id, revision):
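        """Read quantization parameters from config.json, quantize_config.json or
        quant_config.json (tried in that order) and store them on the instance."""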
        filename = "config.json"
        try:
            if os.path.exists(os.path.join(model_id, filename)):
                filename = os.path.join(model_id, filename)
            else:
                filename = hf_hub_download(
                    model_id, filename=filename, revision=revision
                )
            with open(filename, "r") as f:
                data = json.load(f)
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]
            # Order is important here, desc_act is missing on some real models
            self.quant_method = data["quantization_config"]["quant_method"]
            self.gptq_desc_act = data["quantization_config"]["desc_act"]
        except Exception:
            filename = "quantize_config.json"
            try:
                if os.path.exists(os.path.join(model_id, filename)):
                    filename = os.path.join(model_id, filename)
                else:
                    filename = hf_hub_download(
                        model_id, filename=filename, revision=revision
                    )
                with open(filename, "r") as f:
                    data = json.load(f)
                self.gptq_bits = data["bits"]
                self.gptq_groupsize = data["group_size"]
                self.gptq_desc_act = data["desc_act"]
                if "version" in data and data["version"] == "GEMM":
                    self.quant_method = "awq"
            except Exception:
                filename = "quant_config.json"
                try:
                    if os.path.exists(os.path.join(model_id, filename)):
                        filename = os.path.join(model_id, filename)
                    else:
                        filename = hf_hub_download(
                            model_id, filename=filename, revision=revision
                        )
                    with open(filename, "r") as f:
                        data = json.load(f)
                    self.gptq_bits = data["w_bit"]
                    self.gptq_groupsize = data["q_group_size"]
                    self.gptq_desc_act = data["desc_act"]
                    if "version" in data and data["version"] == "GEMM":
                        self.quant_method = "awq"
                except Exception:
                    pass


def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list of block sizes (List[int]).

    In the latter case, if sum(blocks) < total_size, the ratios between
    the block sizes will be preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].
    """
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks
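

# Illustrative usage (a minimal sketch, not part of the production code path).
# The file name, tensor name, and the single-process "gloo" group below are
# assumptions made for the example; real callers pass the process group used
# for tensor parallelism and the safetensors files of an actual checkpoint.
#
#     import torch.distributed as dist
#
#     _blocks_to_block_sizes(total_size=1024, blocks=4)          # [256, 256, 256, 256]
#     _blocks_to_block_sizes(total_size=1024, blocks=[2, 1, 1])  # [512, 256, 256]
#
#     dist.init_process_group(
#         "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
#     )
#     weights = Weights(
#         filenames=[Path("model.safetensors")],
#         device=torch.device("cpu"),
#         dtype=torch.float16,
#         process_group=dist.group.WORLD,
#     )
#     embeddings = weights.get_tensor("model.embed_tokens.weight")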