# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py

from __future__ import annotations

import builtins
import inspect
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union

import torch

try:
    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
        CompressedTensorsW8A8Fp8MoEMethod,
        CompressedTensorsWNA16MoEMethod,
    )
    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
    from vllm.model_executor.layers.quantization.qqq import QQQConfig
    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

    # Define empty classes as placeholders when vllm is not available
    class DummyConfig:
        def override_quantization_method(self, *args, **kwargs):
            return None

    AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
        ExpertsInt8Config
    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
        Int8TpuConfig
    ) = DummyConfig

from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
)
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported

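# fp4 (MXFP4) support is optional: its config is only imported below on
# platforms that report MXFP support.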
is_mxfp_supported = mxfp_supported()
if is_mxfp_supported:
    from sglang.srt.layers.quantization.fp4 import MxFp4Config

from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import (
    GPTQConfig,
    GPTQLinearMethod,
    GPTQMarlinConfig,
    GPTQMarlinLinearMethod,
    GPTQMarlinMoEMethod,
)
from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp4Config,
    ModelOptFp8Config,
)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
from sglang.srt.layers.quantization.petit import PetitNvFp4Config
from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.utils import get_linear_quant_method
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

if TYPE_CHECKING:
    from sglang.srt.layers.moe.topk import TopKOutput

# Base quantization methods that don't depend on vllm
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "fp8": Fp8Config,
    "blockwise_int8": BlockInt8Config,
    "modelopt": ModelOptFp8Config,
    "modelopt_fp4": ModelOptFp4Config,
    "w8a8_int8": W8A8Int8Config,
    "w8a8_fp8": W8A8Fp8Config,
    "moe_wna16": MoeWNA16Config,
    "compressed-tensors": CompressedTensorsConfig,
    "qoq": QoQConfig,
    "w4afp8": W4AFp8Config,
    "petit_nvfp4": PetitNvFp4Config,
}


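# Register the "quark"/"mxfp4" methods with the platform-appropriate config:
# Mxfp4Config on CUDA, MxFp4Config on HIP when MXFP is supported.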
if is_cuda():
    BASE_QUANTIZATION_METHODS.update(
        {
            "quark": Mxfp4Config,
            "mxfp4": Mxfp4Config,
        }
    )
elif is_mxfp_supported and is_hip():
    BASE_QUANTIZATION_METHODS.update(
        {
            "quark": MxFp4Config,
            "mxfp4": MxFp4Config,
        }
    )
# VLLM-dependent quantization methods
VLLM_QUANTIZATION_METHODS = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "tpu_int8": Int8TpuConfig,
    "fbgemm_fp8": FBGEMMFp8Config,
    "marlin": MarlinConfig,
    "gguf": GGUFConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "awq_marlin": AWQMarlinConfig,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
    "experts_int8": ExpertsInt8Config,
    "gptq_marlin": GPTQMarlinConfig,
    "gptq": GPTQConfig,
}

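# Full registry; availability of the vllm-backed entries is enforced at lookup
# time in get_quantization_config().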
QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
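    """Resolve a quantization method name to its config class.

    Example (illustrative):
        >>> get_quantization_config("fp8")
        <class 'sglang.srt.layers.quantization.fp8.Fp8Config'>
    """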
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(
            f"Invalid quantization method: {quantization}. "
            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
        )

    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
        raise ValueError(
            f"{quantization} quantization requires some operators from vllm. "
            "Please install vllm with `pip install vllm==0.9.0.1`."
        )

    return QUANTIZATION_METHODS[quantization]


def gptq_get_quant_method(self, layer, prefix):
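    """Replacement `get_quant_method` for GPTQConfig / GPTQMarlinConfig.

    Routes FusedMoE layers to GPTQMarlinMoEMethod and linear layers to the
    matching sglang linear method; returns None for other layers.
    """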
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

    if isinstance(layer, FusedMoE):
        return GPTQMarlinMoEMethod(self)

    if isinstance(self, GPTQConfig):
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
        )
    elif isinstance(self, GPTQMarlinConfig):
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
        )
    return None


original_isinstance = builtins.isinstance


def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
    """
    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
169
    can recognize sglang layers
170
    """
    if not VLLM_AVAILABLE:
        return

    if reverse:
        builtins.isinstance = original_isinstance
        return

    from vllm.model_executor.layers.fused_moe import FusedMoE
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding,
    )

    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
    from sglang.srt.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
    )

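    # Map vllm's base layer classes to their sglang counterparts so that
    # isinstance checks inside vllm quant configs also match sglang layers.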
    def patched_isinstance(obj, classinfo):
        if classinfo is LinearBase:
            return original_isinstance(obj, PatchedLinearBase)
        if classinfo is FusedMoE:
            return original_isinstance(obj, PatchedFusedMoE)
        if classinfo is VocabParallelEmbedding:
            return original_isinstance(obj, PatchedVocabParallelEmbedding)
        return original_isinstance(obj, classinfo)

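    # Install the patch globally; call again with reverse=True to restore the
    # original builtins.isinstance.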
    builtins.isinstance = patched_isinstance


def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
    """
    Monkey patch the apply function of vllm's FusedMoEMethodBase.
    Convert sglang arguments to vllm arguments.
    """
    original_apply = class_obj.apply
    sig = inspect.signature(original_apply)
    param_names = list(sig.parameters.keys())
    has_correction_bias = "e_score_correction_bias" in param_names

    def new_apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ):
        assert activation == "silu"
        assert inplace and not no_combine

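        # Forward only the reduced argument set; the sglang-only kwargs are
        # constrained by the asserts above and dropped here.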
        kwargs = {
            "self": self,
            "layer": layer,
            "x": x,
            "topk_output": topk_output,
        }
        return original_apply(**kwargs)

    setattr(class_obj, "apply", new_apply)


def monkey_patch_quant_configs():
    """Apply all monkey patches in one place."""
    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


# Only apply monkey patches if vllm is available
if VLLM_AVAILABLE:
    monkey_patch_quant_configs()