# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
import builtins
import inspect
import re
from copy import deepcopy
from typing import Callable, Dict, Optional, Type, Union

import torch

try:
    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinConfig,
        AWQMoEMethod,
    )
    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
        CompressedTensorsW8A8Fp8MoEMethod,
        CompressedTensorsWNA16MoEMethod,
    )
    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
        GPTQMarlinMoEMethod,
    )
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
    from vllm.model_executor.layers.quantization.qqq import QQQConfig
    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

    # Define empty classes as placeholders when vllm is not available
    class DummyConfig:
        def override_quantization_method(self, *args, **kwargs):
            return None

    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
        DeepSpeedFPConfig
    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
        MarlinConfig
    ) = QQQConfig = Int8TpuConfig = DummyConfig

from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    UnquantizedEmbeddingMethod,
)

# Base quantization methods that don't depend on vllm
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "fp8": Fp8Config,
    "blockwise_int8": BlockInt8Config,
    "modelopt": ModelOptFp8Config,
    "w8a8_int8": W8A8Int8Config,
    "w8a8_fp8": W8A8Fp8Config,
    "moe_wna16": MoeWNA16Config,
    "compressed-tensors": CompressedTensorsConfig,
}

# VLLM-dependent quantization methods
VLLM_QUANTIZATION_METHODS = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "tpu_int8": Int8TpuConfig,
    "fbgemm_fp8": FBGEMMFp8Config,
    "marlin": MarlinConfig,
    "gguf": GGUFConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "awq_marlin": AWQMarlinConfig,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
    "experts_int8": ExpertsInt8Config,
    "gptq_marlin": GPTQMarlinConfig,
    "gptq": GPTQConfig,
}

QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
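    """Look up a quantization config class by its registry name.

    Raises ValueError for unknown names, and for vllm-backed methods when
    vllm is not installed.

    Illustrative usage (``hf_quant_config`` is an assumed checkpoint dict,
    not something defined in this module)::

        quant_cls = get_quantization_config("fp8")
        quant_config = quant_cls.from_config(hf_quant_config)
    """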
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(
            f"Invalid quantization method: {quantization}. "
            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
        )
    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
        raise ValueError(
            f"{quantization} quantization requires some operators from vllm. "
            "Please install vllm with `pip install vllm==0.7.2`"
        )

    return QUANTIZATION_METHODS[quantization]


# Match dynamic rules against the module name (prefix) and override the
# quantization config if the module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits
    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size
    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    config.pack_factor = 32 // config.weight_bits  # packed into int32
    if config.get_name() == "gptq_marlin":
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError(
                "Unsupported quantization config: "
                f"bits={config.weight_bits}, sym={config.is_sym}"
            )

        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
    elif config.get_name() == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits."
            )


def get_dynamic_override(
    config: QuantizationConfig,
    layer_name: str,
    key: Optional[str] = None,
    default_value: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
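    """Resolve a per-module override from ``config.dynamic``.

    Keys of ``config.dynamic`` are regex patterns matched against the layer
    name: a ``-:`` prefix excludes matching modules from quantized init
    (returns False); otherwise a match returns the rule dict (or the value of
    ``key`` within it). Illustrative rule set (an assumption, not a shipped
    default)::

        {"+:.*down_proj.*": {"bits": 8}, "-:.*lm_head.*": {}}
    """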
    for pattern, pattern_dict in config.dynamic.items():
        # Negative match: matched modules are excluded from quantized init
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: matched modules get quant property overrides applied
        # on top of the base quant config
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            else:
                return pattern_dict.get(key, default_value)
    return default_value


def get_linear_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
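    """Build the quant method for a linear layer (or a quantized lm_head),
    applying dynamic per-module overrides before instantiating
    ``linear_method_cls``. Returns None for layers this helper does not handle.
    """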
    cloned_config = deepcopy(config)
    parallel_lm_head_quantized = (
        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
    )

    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
        # False = skip module, None = no override, else = Positive match
        if (
            get_dynamic_override(cloned_config, layer_name=prefix)
            == False  # noqa: E712
        ):
            if parallel_lm_head_quantized:
                return UnquantizedEmbeddingMethod()
            return UnquantizedLinearMethod()

        if prefix:
            # Dynamic per module/layer rules may override base config
            override_config(cloned_config, prefix=prefix)

        return linear_method_cls(cloned_config)
    return None


def gptq_get_quant_method(self, layer, prefix):
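    """Replacement ``get_quant_method`` monkey-patched onto GPTQConfig and
    GPTQMarlinConfig so they dispatch to sglang's FusedMoE and linear layers.
    """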
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

    if isinstance(layer, FusedMoE):
        return GPTQMarlinMoEMethod(self)

    if isinstance(self, GPTQConfig):
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
        )
    elif isinstance(self, GPTQMarlinConfig):
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
        )
    return None


# Keep a reference to the builtin isinstance so the patch below can restore it.
original_isinstance = builtins.isinstance


def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
    """
    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
    can recognize sglang layers
    """
    if not VLLM_AVAILABLE:
        return

    if reverse:
        builtins.isinstance = original_isinstance
        return

    from vllm.model_executor.layers.fused_moe import FusedMoE
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding,
    )

    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
    from sglang.srt.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
    )

    def patched_isinstance(obj, classinfo):
        if classinfo is LinearBase:
            return original_isinstance(obj, PatchedLinearBase)
        if classinfo is FusedMoE:
            return original_isinstance(obj, PatchedFusedMoE)
        if classinfo is VocabParallelEmbedding:
            return original_isinstance(obj, PatchedVocabParallelEmbedding)
        return original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance


def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
    """
    Monkey patch the apply function of vllm's FusedMoEMethodBase.
    Convert sglang arguments to vllm arguments.
    """
    original_apply = class_obj.apply
    sig = inspect.signature(original_apply)
    param_names = list(sig.parameters.keys())
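    # Older vllm releases do not expose ``e_score_correction_bias`` in apply();
    # detect it from the signature so we can raise a clear upgrade error below.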
    has_correction_bias = "e_score_correction_bias" in param_names

    def new_apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
    ):
        assert activation == "silu"
        assert inplace and not no_combine

        kwargs = {
            "self": self,
            "layer": layer,
            "x": x,
            "router_logits": router_logits,
            "top_k": top_k,
            "renormalize": renormalize,
            "use_grouped_topk": use_grouped_topk,
            "topk_group": topk_group,
            "num_expert_group": num_expert_group,
            "custom_routing_function": custom_routing_function,
        }
        if correction_bias is not None:
            if not has_correction_bias:
                raise ValueError(
                    "Please upgrade your vllm version. Try `pip install vllm==0.7.2`"
                )
            kwargs["e_score_correction_bias"] = correction_bias
        return original_apply(**kwargs)

    setattr(class_obj, "apply", new_apply)


def monkey_patch_quant_configs():
    """Apply all monkey patches in one place."""
    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

    monkey_patch_moe_apply(AWQMoEMethod)
    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


# Only apply monkey patches if vllm is available
if VLLM_AVAILABLE:
    monkey_patch_quant_configs()