from dataclasses import dataclass, field
import os
from pathlib import Path
from typing import List, Dict, Optional, Set, Tuple, Union
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.utils.log import log_once


class Weights:
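    """Load model weights from a set of safetensors files.

    Maps every tensor name to the file that contains it, resolves optional
    prefixes and aliases, and provides helpers for loading full, sharded,
    and quantized (gptq/awq/exl2) tensors for tensor-parallel execution.
    """
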
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self._handles = {}

    def _get_handle(self, filename):
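        """Open ``filename`` once and cache the safetensors handle for reuse."""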
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
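        """Resolve ``tensor_name`` to ``(filename, resolved_name)``.

        The raw name is tried first, then the ``prefix``-qualified name;
        aliases are consulted for each candidate before giving up.
        """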
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
            names.append(prefixed)
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            aliases = self.aliases.get(name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq, which packs u4 weights disguised as int32;
        # those must not be converted to the target dtype. Exl2 uses int16
        # as well.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
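        # Ceil division: every rank gets the same block size, and the slicing
        # below clamps the last block, so e.g. size=10 on 4 shards yields
        # blocks of 3, 3, 3 and 1.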
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq, which packs u4 weights disguised as int32;
        # those must not be converted to the target dtype. Exl2 uses int16
        # as well.
        if tensor.dtype not in (torch.int16, torch.int32):
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
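        """Like ``get_partial_sharded``, but asserts that the size along
        ``dim`` divides evenly across the world size, so every rank gets an
        equal shard."""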
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def _get_qweight(self, name: str):
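        # The prepacked tensor stores Q, K and V side by side along dim 1:
        # [ Q | K | V ]. Take this rank's column slice from each of the three
        # blocks and re-concatenate them in the same order.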
        slice_ = self._get_slice(name)
        total_size = slice_.get_shape()[1]
        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
        single_size = total_size // 3
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        assert (
            single_size % world_size == 0
        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
        block_size = single_size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        q = slice_[:, start:stop]
        k = slice_[:, start + single_size : stop + single_size]
        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
        weight = torch.cat([q, k, v], dim=1)
        weight = weight.to(device=self.device)
        return weight

    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
        return self.get_weights_col_packed(prefix, quantize, 3)

    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
        return self.get_weights_col_packed(prefix, quantize, 2)

    def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
        """
        Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
        already alternating Q,K,V within the main tensor
        """
        if quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight

            try:
                qweight = self._get_qweight(f"{prefix}.qweight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                )

            bits, groupsize, _, quant_method = self._get_gptq_params()

            qzeros = self._get_qweight(f"{prefix}.qzeros")
            scales = self._get_qweight(f"{prefix}.scales")
            scales = scales.to(dtype=self.dtype)

            if quantize == "gptq" and quant_method == "gptq":
                g_idx = self.get_tensor(f"{prefix}.g_idx")
            elif quantize == "gptq" and quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                g_idx = (
                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
                    // groupsize
                ).to(dtype=torch.int32)
            else:
                g_idx = None

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=bits,
                groupsize=groupsize,
                use_exllama=False,
            )
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
            total_size = slice_.get_shape()[0]
            assert total_size % blocks == 0, f"Prepacked weight is not divisible by {blocks}"
            single_size = total_size // blocks
            world_size = self.process_group.size()
            rank = self.process_group.rank()

            assert (
                single_size % world_size == 0
            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
            block_size = single_size // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            tensors = []
            for i in range(blocks):
                tensor = slice_[start + i * single_size : stop + i * single_size]
                tensors.append(tensor)
            weight = torch.cat(tensors, dim=0)
            weight = weight.to(device=self.device)
            weight = weight.to(dtype=self.dtype)
        return weight

    def get_weights_col(self, prefix: str, quantize: str):
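        """Load this rank's column shard of the weight under ``prefix``.
        Exl2 weights are returned whole, as they are not sharded."""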
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        return self.get_multi_weights_col([prefix], quantize, 0)

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
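        """Load and concatenate the column shards of several ``prefixes``
        (e.g. separate q/k/v projections) along ``dim``."""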
        if quantize == "exl2":
            raise ValueError("get_multi_weights_col is not supported for exl2")
        elif quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight

            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                )

            qzeros = torch.cat(
                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
            scales = torch.cat(
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )

            bits, groupsize, desc_act, quant_method = self._get_gptq_params()

            from text_generation_server.layers.gptq import HAS_EXLLAMA

            use_exllama = (
                bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
            )

            if quantize == "gptq" and quant_method == "gptq":
                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]
            elif quantize == "gptq" and quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
                            qweight.shape[0] * (32 // bits), device=qweight.device
                        )
                        // groupsize
                    ).to(dtype=torch.int32)
            else:
                g_idx = None

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=bits,
                groupsize=groupsize,
                use_exllama=use_exllama,
            )
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)
        return weight

    def get_tensor_shard(self, var, dim):
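        """Shard an already-materialized tensor ``var`` along ``dim`` for
        this rank. Uses floor division, so the size along ``dim`` is
        expected to divide evenly across the world size."""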
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_multi_weights_row(self, prefix: str, quantize: str):
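        """Load this rank's row shard of the weight under ``prefix``.

        Unquantized weights are sharded along dim 1; gptq/awq packed
        tensors are sharded along dim 0; exl2 weights are returned whole.
        Exllama kernels are used only when the quantization parameters
        allow it.
        """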
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        elif quantize == "gptq":
            use_exllama = True
            bits, groupsize, desc_act, quant_method = self._get_gptq_params()

            if bits != 4:
                use_exllama = False

            if desc_act:
                log_once(logger.warning, "Disabling exllama because desc_act=True")
                use_exllama = False

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            if quant_method == "gptq":
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
            elif quant_method == "awq":
                g_idx = None

            if self.process_group.size() > 1:
                if g_idx is not None:
                    if (
                        not torch.equal(
                            g_idx.cpu(),
                            torch.tensor(
                                [i // groupsize for i in range(g_idx.shape[0])],
                                dtype=torch.int32,
                            ),
                        )
                        and not (g_idx == 0).all()
                    ):
                        # Exllama implementation does not support row tensor parallelism with act-order, as
                        # it would require reordering input activations that are split onto several GPUs
                        use_exllama = False

            from text_generation_server.layers.gptq import (
                HAS_EXLLAMA,
                CAN_EXLLAMA,
                GPTQWeight,
            )

            if use_exllama:
                if not HAS_EXLLAMA:
                    if CAN_EXLLAMA:
                        log_once(
                            logger.warning,
                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
                        )
                    use_exllama = False
                else:
                    log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

            if use_exllama and groupsize != -1:
                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                scales = self.get_sharded(f"{prefix}.scales", dim=0)
            else:
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")

            if use_exllama and g_idx is not None:
                g_idx = g_idx - g_idx[0]

            if quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
                            qweight.shape[0] * (32 // bits), device=qweight.device
                        )
                        // groupsize
                    ).to(dtype=torch.int32)

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=bits,
                groupsize=groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "awq":
            from text_generation_server.layers.gptq import GPTQWeight

            bits, groupsize, _, _ = self._get_gptq_params()

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `awq` weight, make sure the model is already quantized"
                )

            qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
            scales = self.get_sharded(f"{prefix}.scales", dim=0)
            g_idx = None
            use_exllama = False

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=bits,
                groupsize=groupsize,
                use_exllama=use_exllama,
            )
        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

    def _get_gptq_params(self) -> Tuple[int, int, bool, str]:
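        """Return ``(bits, groupsize, desc_act, quant_method)``.

        Prefers the ``gptq_bits``/``gptq_groupsize`` tensors stored inside
        the checkpoint itself, falling back to the attributes populated by
        ``_set_gptq_params``.
        """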
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
            desc_act = False
            quant_method = "gptq"
        except (SafetensorError, RuntimeError) as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
                desc_act = getattr(self, "gptq_desc_act", False)
                quant_method = getattr(self, "quant_method", "gptq")
            except Exception:
                raise e

        return bits, groupsize, desc_act, quant_method

    def _set_gptq_params(self, model_id, revision):
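        """Read quantization parameters from the model's config files.

        Tries ``config.json`` (its ``quantization_config`` section), then
        ``quantize_config.json``, then ``quant_config.json``; each file is
        looked up locally first, then downloaded from the Hub.
        """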
        filename = "config.json"
        try:
            if os.path.exists(os.path.join(model_id, filename)):
                filename = os.path.join(model_id, filename)
            else:
                filename = hf_hub_download(
                    model_id, filename=filename, revision=revision
                )
            with open(filename, "r") as f:
                data = json.load(f)
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]
            # Order is important here, desc_act is missing on some real models
            self.quant_method = data["quantization_config"]["quant_method"]
            self.gptq_desc_act = data["quantization_config"]["desc_act"]
        except Exception:
            filename = "quantize_config.json"
            try:
                if os.path.exists(os.path.join(model_id, filename)):
                    filename = os.path.join(model_id, filename)
                else:
                    filename = hf_hub_download(
                        model_id, filename=filename, revision=revision
                    )
                with open(filename, "r") as f:
                    data = json.load(f)
                self.gptq_bits = data["bits"]
                self.gptq_groupsize = data["group_size"]
                self.gptq_desc_act = data["desc_act"]
                if "version" in data and data["version"] == "GEMM":
                    self.quant_method = "awq"
            except Exception:
                filename = "quant_config.json"
                try:
                    if os.path.exists(os.path.join(model_id, filename)):
                        filename = os.path.join(model_id, filename)
                    else:
                        filename = hf_hub_download(
                            model_id, filename=filename, revision=revision
                        )
                    with open(filename, "r") as f:
                        data = json.load(f)
                    self.gptq_bits = data["w_bit"]
                    self.gptq_groupsize = data["q_group_size"]
                    self.gptq_desc_act = data["desc_act"]
                    if "version" in data and data["version"] == "GEMM":
                        self.quant_method = "awq"
                except Exception:
                    pass
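

# Minimal usage sketch (illustrative only: the filename, tensor name, and the
# single-process gloo group below are assumptions, not part of this module):
#
#     import torch.distributed as dist
#
#     dist.init_process_group(
#         "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
#     )
#     weights = Weights(
#         filenames=[Path("model.safetensors")],
#         device=torch.device("cpu"),
#         dtype=torch.float16,
#         process_group=dist.group.WORLD,
#     )
#     w = weights.get_sharded("model.embed_tokens.weight", dim=0)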