# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py

from __future__ import annotations

import builtins
import inspect
from typing import TYPE_CHECKING, Dict, Optional, Type

import torch

try:
    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
        CompressedTensorsW8A8Fp8MoEMethod,
        CompressedTensorsWNA16MoEMethod,
    )
    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
    from vllm.model_executor.layers.quantization.qqq import QQQConfig
    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

    VLLM_AVAILABLE = True
except ImportError as e:
    VLLM_AVAILABLE = False
    VLLM_IMPORT_ERROR = e

    # Define empty classes as placeholders when vllm is not available
    class DummyConfig:
        def override_quantization_method(self, *args, **kwargs):
            return None

    AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
        ExpertsInt8Config
    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
        Int8TpuConfig
    ) = DummyConfig


from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
)
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported

is_mxfp_supported = mxfp_supported()
if is_mxfp_supported:
    from sglang.srt.layers.quantization.fp4 import MxFp4Config

from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp4Config,
    ModelOptFp8Config,
)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
from sglang.srt.layers.quantization.petit import PetitNvFp4Config
from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

if TYPE_CHECKING:
    from sglang.srt.layers.moe.topk import TopKOutput

# Base quantization methods that don't depend on vllm
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "fp8": Fp8Config,
    "blockwise_int8": BlockInt8Config,
    "modelopt": ModelOptFp8Config,
    "modelopt_fp4": ModelOptFp4Config,
    "w8a8_int8": W8A8Int8Config,
    "w8a8_fp8": W8A8Fp8Config,
    "awq": AWQConfig,
    "awq_marlin": AWQMarlinConfig,
    "gptq": GPTQConfig,
    "gptq_marlin": GPTQMarlinConfig,
    "moe_wna16": MoeWNA16Config,
    "compressed-tensors": CompressedTensorsConfig,
    "qoq": QoQConfig,
    "w4afp8": W4AFp8Config,
    "petit_nvfp4": PetitNvFp4Config,
}
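
# Platform-specific entries: on CUDA both "mxfp4" and "quark" resolve to
# Mxfp4Config; on ROCm (HIP) builds with MXFP support they resolve to the
# MxFp4Config implementation from the fp4 module instead.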


if is_cuda():
    BASE_QUANTIZATION_METHODS.update(
        {
            "quark": Mxfp4Config,
            "mxfp4": Mxfp4Config,
        }
    )
elif is_mxfp_supported and is_hip():
    BASE_QUANTIZATION_METHODS.update(
        {
            "quark": MxFp4Config,
            "mxfp4": MxFp4Config,
        }
    )
# VLLM-dependent quantization methods
VLLM_QUANTIZATION_METHODS = {
    "aqlm": AQLMConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "tpu_int8": Int8TpuConfig,
    "fbgemm_fp8": FBGEMMFp8Config,
    "marlin": MarlinConfig,
    "gguf": GGUFConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
    "experts_int8": ExpertsInt8Config,
}

QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
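# The combined registry keeps the vllm-backed names visible even when vllm is
# not installed (they alias DummyConfig in that case), so that
# get_quantization_config() below can report a targeted installation hint
# instead of "unknown method".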


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(
            f"Invalid quantization method: {quantization}. "
            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
        )
    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
        raise ValueError(
            f"{quantization} quantization requires some operators from vllm. "
            f"Please install vllm with `pip install vllm==0.9.0.1`\n"
            f"Import error: {VLLM_IMPORT_ERROR}"
        )

    return QUANTIZATION_METHODS[quantization]
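
# Example lookup (sketch, for illustration only; `quant_cls` is a hypothetical
# caller-side variable):
#
#   quant_cls = get_quantization_config("fp8")   # -> Fp8Config
#   assert issubclass(quant_cls, QuantizationConfig)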


original_isinstance = builtins.isinstance


def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
    """
    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
    can recognize sglang layers
    """
    if not VLLM_AVAILABLE:
        return

    if reverse:
        builtins.isinstance = original_isinstance
        return

    from vllm.model_executor.layers.fused_moe import FusedMoE
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding,
    )

    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
    from sglang.srt.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
    )

    def patched_isinstance(obj, classinfo):
        if classinfo is LinearBase:
            return original_isinstance(obj, PatchedLinearBase)
        if classinfo is FusedMoE:
            return original_isinstance(obj, PatchedFusedMoE)
        if classinfo is VocabParallelEmbedding:
            return original_isinstance(obj, PatchedVocabParallelEmbedding)
        return original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance
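
# Typical call pattern (sketch): bracket the vllm call that dispatches on
# isinstance(layer, LinearBase) / FusedMoE / VocabParallelEmbedding, then
# restore the builtin. `vllm_quant_config`, `layer`, and `prefix` are
# placeholders for whatever the caller has in scope.
#
#   monkey_patch_isinstance_for_vllm_base_layer()
#   try:
#       quant_method = vllm_quant_config.get_quant_method(layer, prefix)
#   finally:
#       monkey_patch_isinstance_for_vllm_base_layer(reverse=True)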


def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
    """
    Monkey patch the apply function of vllm's FusedMoEMethodBase.
    Convert sglang arguments to vllm arguments.
    """
    original_apply = class_obj.apply
    sig = inspect.signature(original_apply)
    param_names = list(sig.parameters.keys())
    has_correction_bias = "e_score_correction_bias" in param_names

    def new_apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ):
        assert activation == "silu"
        assert inplace and not no_combine

        kwargs = {
            "self": self,
            "layer": layer,
            "x": x,
            "topk_output": topk_output,
        }
        return original_apply(**kwargs)

    setattr(class_obj, "apply", new_apply)
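
# Note: new_apply above forwards only (layer, x, topk_output) to the original
# vllm apply(); the extra sglang keyword arguments (activation, inplace,
# no_combine, apply_router_weight_on_input, routed_scaling_factor) are accepted
# for signature compatibility but are either asserted to their defaults or
# dropped before delegating.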


def monkey_patch_quant_configs():
    """Apply all monkey patches in one place."""

    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


# Only apply monkey patches if vllm is available
if VLLM_AVAILABLE:
    monkey_patch_quant_configs()