# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py

from __future__ import annotations

import builtins
import inspect
from typing import TYPE_CHECKING, Dict, Optional, Type

import torch

try:
    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
        CompressedTensorsW8A8Fp8MoEMethod,
        CompressedTensorsWNA16MoEMethod,
    )
    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
    from vllm.model_executor.layers.quantization.qqq import QQQConfig
    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

    VLLM_AVAILABLE = True
except ImportError as e:
    VLLM_AVAILABLE = False
    VLLM_IMPORT_ERROR = e

    # Define a placeholder class to stand in for the vllm configs when vllm is not available
    class DummyConfig:
        def override_quantization_method(self, *args, **kwargs):
            return None

    AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
        ExpertsInt8Config
    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
        Int8TpuConfig
    ) = DummyConfig


from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp4Config,
    ModelOptFp8Config,
)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
from sglang.srt.layers.quantization.petit import PetitNvFp4Config
from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported

_is_mxfp_supported = mxfp_supported()

if TYPE_CHECKING:
    from sglang.srt.layers.moe.topk import TopKOutput

# Base quantization methods that don't depend on vllm
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "fp8": Fp8Config,
    "blockwise_int8": BlockInt8Config,
    "modelopt": ModelOptFp8Config,
    "modelopt_fp4": ModelOptFp4Config,
    "w8a8_int8": W8A8Int8Config,
    "w8a8_fp8": W8A8Fp8Config,
    "awq": AWQConfig,
    "awq_marlin": AWQMarlinConfig,
    "gptq": GPTQConfig,
    "gptq_marlin": GPTQMarlinConfig,
    "moe_wna16": MoeWNA16Config,
    "compressed-tensors": CompressedTensorsConfig,
    "qoq": QoQConfig,
    "w4afp8": W4AFp8Config,
    "petit_nvfp4": PetitNvFp4Config,
}


if is_cuda():
    BASE_QUANTIZATION_METHODS.update(
        {
            "quark": Mxfp4Config,
            "mxfp4": Mxfp4Config,
        }
    )
elif _is_mxfp_supported and is_hip():
    from sglang.srt.layers.quantization.quark.quark import QuarkConfig

    BASE_QUANTIZATION_METHODS.update(
        {
            "quark": QuarkConfig,
            "mxfp4": Mxfp4Config,
        }
    )


# VLLM-dependent quantization methods
VLLM_QUANTIZATION_METHODS = {
    "aqlm": AQLMConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "tpu_int8": Int8TpuConfig,
    "fbgemm_fp8": FBGEMMFp8Config,
    "marlin": MarlinConfig,
    "gguf": GGUFConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
    "experts_int8": ExpertsInt8Config,
}

QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(
            f"Invalid quantization method: {quantization}. "
            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
        )
    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
        raise ValueError(
            f"{quantization} quantization requires some operators from vllm. "
            f"Please install vllm by `pip install vllm==0.9.0.1`\n"
            f"Import error: {VLLM_IMPORT_ERROR}"
        )

    return QUANTIZATION_METHODS[quantization]
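
# Usage sketch (hypothetical call site; the config dict below is an illustrative
# assumption, not taken from a real checkpoint):
#
#     config_cls = get_quantization_config("fp8")
#     quant_config = config_cls.from_config(
#         {"quant_method": "fp8", "activation_scheme": "dynamic"}
#     )
#
# Unknown names raise ValueError; vllm-backed names raise until vllm is installed.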


original_isinstance = builtins.isinstance


def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
    """
    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
    can recognize sglang layers.
    """
    if not VLLM_AVAILABLE:
        return

    if reverse:
        builtins.isinstance = original_isinstance
        return

    from vllm.model_executor.layers.fused_moe import FusedMoE
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding,
    )

    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
    from sglang.srt.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
    )

    def patched_isinstance(obj, classinfo):
        if classinfo is LinearBase:
            return original_isinstance(obj, PatchedLinearBase)
        if classinfo is FusedMoE:
            return original_isinstance(obj, PatchedFusedMoE)
        if classinfo is VocabParallelEmbedding:
            return original_isinstance(obj, PatchedVocabParallelEmbedding)
        return original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance
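
# Usage sketch (hypothetical): vllm's `get_quant_method` isinstance-checks its
# own layer classes, so wrap calls into vllm with the patch and restore the
# builtin afterwards. `quant_config` and `layer` are placeholder names here.
#
#     monkey_patch_isinstance_for_vllm_base_layer()
#     try:
#         quant_method = quant_config.get_quant_method(layer, prefix="")
#     finally:
#         monkey_patch_isinstance_for_vllm_base_layer(reverse=True)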


def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
    """
    Monkey patch the apply function of vllm's FusedMoEMethodBase.
    Convert sglang arguments to vllm arguments.
    """
    original_apply = class_obj.apply
    # Detect whether this vllm build's apply() accepts `e_score_correction_bias`
    # (the signature varies across vllm versions); the flag is informational only.
    sig = inspect.signature(original_apply)
    param_names = list(sig.parameters.keys())
    has_correction_bias = "e_score_correction_bias" in param_names

    def new_apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ):
        assert activation == "silu"
        assert inplace and not no_combine

        kwargs = {
            "self": self,
            "layer": layer,
            "x": x,
            "topk_output": topk_output,
        }
        return original_apply(**kwargs)

    setattr(class_obj, "apply", new_apply)
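
# Usage sketch (hypothetical): once patched, an sglang call site can run
#
#     out = method.apply(layer, x, topk_output, activation="silu", inplace=True)
#
# where `method` is an instance of the patched class; new_apply asserts the
# sglang-only keyword arguments hold their supported values and delegates
# (self, layer, x, topk_output) to the original vllm apply().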


def monkey_patch_quant_configs():
    """Apply all monkey patches in one place."""

    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


# Only apply monkey patches if vllm is available
if VLLM_AVAILABLE:
    monkey_patch_quant_configs()