import json
import os
from dataclasses import dataclass
from typing import Optional

from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    UnquantizedWeight,
    WeightsLoader,
)


@dataclass
class _QuantizerConfig:
    bits: int
    checkpoint_format: Optional[str]
    desc_act: bool
    groupsize: int
    quant_method: str
    sym: bool
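
# A hedged sketch, not used by this module: the Pydantic route that the
# comment below suggests could look roughly like this (assuming pydantic v2;
# field names mirror the quantization_config keys read by
# _get_quantizer_config):
#
#   from pydantic import BaseModel
#
#   class QuantizationConfig(BaseModel):
#       bits: int = 4
#       group_size: int = -1
#       quant_method: str = "gptq"
#       checkpoint_format: Optional[str] = None
#       sym: bool = True
#       desc_act: bool = False
#
#   config = QuantizationConfig.model_validate(data["quantization_config"])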


# We should probably do this with Pydantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id: str, revision: Optional[str]) -> _QuantizerConfig:
    bits = 4
    groupsize = -1
    quant_method = "gptq"
    checkpoint_format = None
    sym = True
    desc_act = False

    filename = "config.json"
    try:
        if os.path.exists(os.path.join(model_id, filename)):
            filename = os.path.join(model_id, filename)
        else:
            filename = hf_hub_download(model_id, filename=filename, revision=revision)
        with open(filename, "r") as f:
            data = json.load(f)
        bits = data["quantization_config"]["bits"]
        groupsize = data["quantization_config"]["group_size"]
        # Order is important here: desc_act is missing on some real models,
        # so it is read last. If the KeyError sends us to the fallback files
        # below, the fields already read from config.json keep their values.
        quant_method = data["quantization_config"]["quant_method"]
        checkpoint_format = data["quantization_config"].get("checkpoint_format")
        sym = data["quantization_config"]["sym"]
        desc_act = data["quantization_config"]["desc_act"]
    except Exception:
        filename = "quantize_config.json"
        try:
            if os.path.exists(os.path.join(model_id, filename)):
                filename = os.path.join(model_id, filename)
            else:
                filename = hf_hub_download(
                    model_id, filename=filename, revision=revision
                )
            with open(filename, "r") as f:
                data = json.load(f)
            bits = data["bits"]
            groupsize = data["group_size"]
            sym = data["sym"]
            desc_act = data["desc_act"]
            if "version" in data and data["version"] == "GEMM":
                quant_method = "awq"
        except Exception:
            filename = "quant_config.json"
            try:
                if os.path.exists(os.path.join(model_id, filename)):
                    filename = os.path.join(model_id, filename)
                else:
                    filename = hf_hub_download(
                        model_id, filename=filename, revision=revision
                    )
                with open(filename, "r") as f:
                    data = json.load(f)
                bits = data["w_bit"]
                groupsize = data["q_group_size"]
                desc_act = data["desc_act"]
                if "version" in data and data["version"] == "GEMM":
                    quant_method = "awq"
            except Exception:
                pass

    return _QuantizerConfig(
        bits=bits,
        groupsize=groupsize,
        quant_method=quant_method,
        checkpoint_format=checkpoint_format,
        sym=sym,
        desc_act=desc_act,
    )
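
# Illustrative shapes (hypothetical values, not taken from any specific
# model) for the three files probed above, in order:
#
# config.json:
#   {"quantization_config": {"bits": 4, "group_size": 128,
#    "quant_method": "gptq", "checkpoint_format": null,
#    "sym": true, "desc_act": false}}
#
# quantize_config.json:
#   {"bits": 4, "group_size": 128, "sym": true, "desc_act": false,
#    "version": "GEMM"}
#
# quant_config.json:
#   {"w_bit": 4, "q_group_size": 128, "desc_act": false, "version": "GEMM"}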


def get_loader(
    quantize: Optional[str], model_id: str, revision: Optional[str]
) -> WeightsLoader:
    quantizer_config = _get_quantizer_config(model_id, revision)
    if quantize in {"awq", "gptq"}:
        from text_generation_server.layers.gptq import GPTQWeightsLoader

        return GPTQWeightsLoader(
            bits=quantizer_config.bits,
            desc_act=quantizer_config.desc_act,
            groupsize=quantizer_config.groupsize,
            quant_method=quantizer_config.quant_method,
            quantize=quantize,
            sym=quantizer_config.sym,
        )
    elif quantize == "bitsandbytes":
        from text_generation_server.layers.bnb import BNBWeight

        return DefaultWeightsLoader(BNBWeight)
    elif quantize == "bitsandbytes-fp4":
        from text_generation_server.layers.bnb import BNBFP4Weight

        return DefaultWeightsLoader(BNBFP4Weight)
    elif quantize == "bitsandbytes-nf4":
        from text_generation_server.layers.bnb import BNBNF4Weight

        return DefaultWeightsLoader(BNBNF4Weight)
    elif quantize == "eetq":
        from text_generation_server.layers.eetq import EETQWeight

        return DefaultWeightsLoader(EETQWeight)
    elif quantize == "exl2":
        from text_generation_server.layers.exl2 import Exl2WeightsLoader

        return Exl2WeightsLoader()
    elif quantize == "fp8":
        from text_generation_server.layers.fp8 import Fp8Weight

        return DefaultWeightsLoader(Fp8Weight)
    elif quantize == "marlin":
        from text_generation_server.layers.marlin import MarlinWeightsLoader

        return MarlinWeightsLoader(
            bits=quantizer_config.bits,
            is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
        )
    elif quantize is None:
        return DefaultWeightsLoader(UnquantizedWeight)
    else:
        raise ValueError(f"Unknown quantization method: {quantize}")
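

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The model id is
    # hypothetical; any GPTQ-quantized repository whose config.json carries a
    # quantization_config would behave the same way.
    model_id = "some-org/some-gptq-model"
    print(_get_quantizer_config(model_id, revision=None))
    loader = get_loader("gptq", model_id, revision=None)
    print(f"Selected loader: {type(loader).__name__}")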