Unverified commit 64cf868e authored by Yineng Zhang, committed by GitHub

chore: cleanup quant deps (#12268)

parent ea399527
@@ -7,31 +7,14 @@ from typing import TYPE_CHECKING, Dict, Optional, Type

 import torch

-try:
-    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
-    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
-        GPTQMarlin24Config,
-    )
-    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-    from vllm.model_executor.layers.quantization.qqq import QQQConfig
-    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
-
-    VLLM_AVAILABLE = True
-except ImportError as e:
-    VLLM_AVAILABLE = False
-    VLLM_IMPORT_ERROR = e
-
 # Define empty classes as placeholders when vllm is not available
 class DummyConfig:
     def override_quantization_method(self, *args, **kwargs):
         return None

-AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
-    ExpertsInt8Config
-) = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
+CompressedTensorsConfig = DummyConfig

 from sglang.srt.layers.quantization.auto_round import AutoRoundConfig
 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
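The block deleted above is the classic guarded-import pattern for a soft optional dependency: try the import, record success or the `ImportError`, and bind placeholder classes so later type references still resolve. A minimal self-contained sketch of the same idea; the package name `fastkernels` and class `TurboConfig` are hypothetical stand-ins, not sglang's or vllm's API:

```python
# Sketch of an optional-dependency fallback, assuming a hypothetical
# optional package "fastkernels". Not the actual sglang code.
try:
    from fastkernels import TurboConfig  # optional accelerator package

    FASTKERNELS_AVAILABLE = True
    FASTKERNELS_IMPORT_ERROR = None
except ImportError as e:
    FASTKERNELS_AVAILABLE = False
    FASTKERNELS_IMPORT_ERROR = e

    class TurboConfig:  # placeholder so references to the name still work
        def override_quantization_method(self, *args, **kwargs):
            return None
```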
@@ -62,7 +45,7 @@ _is_mxfp_supported = mxfp_supported()

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput

-# Base quantization methods that don't depend on vllm
+# Base quantization methods
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
@@ -102,19 +85,8 @@ elif _is_mxfp_supported and is_hip():
         "mxfp4": Mxfp4Config,
     }
 )

-# VLLM-dependent quantization methods
-VLLM_QUANTIZATION_METHODS = {
-    "aqlm": AQLMConfig,
-    "deepspeedfp": DeepSpeedFPConfig,
-    "tpu_int8": Int8TpuConfig,
-    "marlin": MarlinConfig,
-    "gptq_marlin_24": GPTQMarlin24Config,
-    "bitsandbytes": BitsAndBytesConfig,
-    "qqq": QQQConfig,
-    "experts_int8": ExpertsInt8Config,
-}
-QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
+QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS}

 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
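The `+` side keeps the registry-lookup pattern intact: method names map to config classes, and `**`-unpacking lets platform-specific tables extend the base one. A standalone sketch of that pattern, with illustrative stand-in classes rather than the real configs:

```python
from typing import Dict, Type

class QuantizationConfig:  # stand-in base class for the sketch
    pass

class Fp8Config(QuantizationConfig):
    pass

BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {"fp8": Fp8Config}

# Platform-specific tables can be merged in the same way the real module
# does; here there is nothing extra to add.
QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS}

def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(
            f"Invalid quantization method: {quantization}. "
            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
        )
    return QUANTIZATION_METHODS[quantization]

assert get_quantization_config("fp8") is Fp8Config
```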
@@ -123,50 +95,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
             f"Invalid quantization method: {quantization}. "
             f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
         )
-    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
-        raise ValueError(
-            f"{quantization} quantization requires some operators from vllm. "
-            f"Please install vllm by `pip install vllm==0.9.0.1`\n"
-            f"Import error: {VLLM_IMPORT_ERROR}"
-        )
     return QUANTIZATION_METHODS[quantization]

 original_isinstance = builtins.isinstance

-def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
-    """
-    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
-    can recognize sglang layers.
-    """
-    if not VLLM_AVAILABLE:
-        return
-
-    if reverse:
-        builtins.isinstance = original_isinstance
-        return
-
-    from vllm.model_executor.layers.fused_moe import FusedMoE
-    from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.vocab_parallel_embedding import (
-        VocabParallelEmbedding,
-    )
-
-    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
-    )
-
-    def patched_isinstance(obj, classinfo):
-        if classinfo is LinearBase:
-            return original_isinstance(obj, PatchedLinearBase)
-        if classinfo is FusedMoE:
-            return original_isinstance(obj, PatchedFusedMoE)
-        if classinfo is VocabParallelEmbedding:
-            return original_isinstance(obj, PatchedVocabParallelEmbedding)
-        return original_isinstance(obj, classinfo)
-
-    builtins.isinstance = patched_isinstance
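The deleted helper worked by swapping `builtins.isinstance` for a shim that redirects checks against vllm's base classes to sglang's equivalents, with a `reverse` flag to restore the original. A self-contained sketch of that reversible-patch technique, using hypothetical classes in place of the vllm/sglang layers:

```python
import builtins

class LibBase:          # stands in for a third-party base class
    pass

class PatchedBase:      # stands in for the local replacement class
    pass

_original_isinstance = builtins.isinstance

def monkey_patch_isinstance(reverse: bool = False):
    if reverse:
        builtins.isinstance = _original_isinstance
        return

    def patched_isinstance(obj, classinfo):
        # Redirect checks against LibBase to the local class;
        # everything else falls through to the real isinstance.
        if classinfo is LibBase:
            return _original_isinstance(obj, PatchedBase)
        return _original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance

monkey_patch_isinstance()
assert isinstance(PatchedBase(), LibBase)       # redirected check
monkey_patch_isinstance(reverse=True)
assert not isinstance(PatchedBase(), LibBase)   # original behavior restored
```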
@@ -77,7 +77,6 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -730,7 +729,6 @@ class ModelRunner:
         # Load the model
         # Remove this monkey patch once linear.py's quantization no longer depends on vllm
         monkey_patch_vllm_parallel_state()
-        monkey_patch_isinstance_for_vllm_base_layer()
         with self.memory_saver_adapter.region(
             GPU_MEMORY_TYPE_WEIGHTS,
@@ -742,7 +740,6 @@ class ModelRunner:
             device_config=DeviceConfig(self.device, self.gpu_id),
         )
         monkey_patch_vllm_parallel_state(reverse=True)
-        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

         get_offloader().post_init()
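Both call sites follow the same apply/reverse bracket around model loading. One way to make the remaining `monkey_patch_vllm_parallel_state` bracket robust is a try/finally (or context manager), which guarantees the reverse call even if loading raises; a hedged sketch, with a stub standing in for the real sglang helper:

```python
from contextlib import contextmanager

def monkey_patch_vllm_parallel_state(reverse: bool = False):
    # Stub for the sketch; the real helper lives in sglang's utils.
    print("restore" if reverse else "apply")

@contextmanager
def vllm_parallel_state_patched():
    monkey_patch_vllm_parallel_state()
    try:
        yield
    finally:
        # Always undo the patch, even if model loading raises.
        monkey_patch_vllm_parallel_state(reverse=True)

with vllm_parallel_state_patched():
    pass  # load the model here
```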