Unverified Commit bf8d07a6 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

feat: patch linear base (#2915)

parent ab317936
...@@ -16,9 +16,6 @@ from vllm.distributed import ( ...@@ -16,9 +16,6 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
from vllm.model_executor.layers.linear import LinearBase
from sglang.srt.layers.parameter import ( from sglang.srt.layers.parameter import (
BasevLLMParameter, BasevLLMParameter,
PackedColumnParameter, PackedColumnParameter,
...@@ -174,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -174,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase):
return F.linear(x, layer.weight, bias) return F.linear(x, layer.weight, bias)
class LinearBase(torch.nn.Module):
"""Base linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase): class ReplicatedLinear(LinearBase):
"""Replicated linear layer. """Replicated linear layer.
......
...@@ -58,12 +58,11 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -58,12 +58,11 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
def fp8_get_quant_method(self, layer, prefix): def fp8_get_quant_method(self, layer, prefix):
"""Enhanced get_quant_method for FP8 config.""" """Enhanced get_quant_method for FP8 config."""
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped, is_layer_skipped,
) )
from sglang.srt.layers.linear import UnquantizedLinearMethod from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
...@@ -77,12 +76,12 @@ def fp8_get_quant_method(self, layer, prefix): ...@@ -77,12 +76,12 @@ def fp8_get_quant_method(self, layer, prefix):
def gptq_get_quant_method(self, layer, prefix): def gptq_get_quant_method(self, layer, prefix):
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod, GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod, GPTQMarlinMoEMethod,
) )
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
...@@ -93,12 +92,12 @@ def gptq_get_quant_method(self, layer, prefix): ...@@ -93,12 +92,12 @@ def gptq_get_quant_method(self, layer, prefix):
def awq_get_quant_method(self, layer, prefix): def awq_get_quant_method(self, layer, prefix):
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.awq_marlin import ( from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinLinearMethod, AWQMarlinLinearMethod,
AWQMoEMethod, AWQMoEMethod,
) )
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
...@@ -108,6 +107,23 @@ def awq_get_quant_method(self, layer, prefix): ...@@ -108,6 +107,23 @@ def awq_get_quant_method(self, layer, prefix):
return None return None
def patch_vllm_linear_base_isinstance():
import builtins
from vllm.model_executor.layers.linear import LinearBase
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
original_isinstance = builtins.isinstance
def patched_isinstance(obj, classinfo):
if classinfo is LinearBase:
return original_isinstance(obj, PatchedLinearBase)
return original_isinstance(obj, classinfo)
builtins.isinstance = patched_isinstance
def apply_monkey_patches(): def apply_monkey_patches():
"""Apply all monkey patches in one place.""" """Apply all monkey patches in one place."""
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method) setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
...@@ -115,6 +131,7 @@ def apply_monkey_patches(): ...@@ -115,6 +131,7 @@ def apply_monkey_patches():
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method) setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
patch_vllm_linear_base_isinstance()
# Apply patches when module is imported # Apply patches when module is imported
apply_monkey_patches() apply_monkey_patches()
......
...@@ -9,7 +9,6 @@ from torch.nn import Module ...@@ -9,7 +9,6 @@ from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, apply_fp8_marlin_linear,
...@@ -25,7 +24,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -25,7 +24,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
requantize_with_max_scale, requantize_with_max_scale,
) )
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod from sglang.srt.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
......
...@@ -5,14 +5,13 @@ from typing import Any, Dict, List, Optional ...@@ -5,14 +5,13 @@ from typing import Any, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, apply_fp8_linear,
cutlass_fp8_supported, cutlass_fp8_supported,
requantize_with_max_scale, requantize_with_max_scale,
) )
from sglang.srt.layers.linear import LinearMethodBase from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
......
...@@ -54,7 +54,7 @@ class W8A8Int8Config(QuantizationConfig): ...@@ -54,7 +54,7 @@ class W8A8Int8Config(QuantizationConfig):
layer: torch.nn.Module, layer: torch.nn.Module,
prefix: str, prefix: str,
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from vllm.model_executor.layers.linear import LinearBase from sglang.srt.layers.linear import LinearBase
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return W8A8Int8LinearMethod(self) return W8A8Int8LinearMethod(self)
......
...@@ -574,13 +574,13 @@ def monkey_patch_vllm_all_gather(reverse: bool = False): ...@@ -574,13 +574,13 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
def monkey_patch_vllm_gguf_config(): def monkey_patch_vllm_gguf_config():
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.gguf import ( from vllm.model_executor.layers.quantization.gguf import (
GGUFConfig, GGUFConfig,
GGUFEmbeddingMethod, GGUFEmbeddingMethod,
GGUFLinearMethod, GGUFLinearMethod,
) )
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
def get_quant_method_with_embedding_replaced( def get_quant_method_with_embedding_replaced(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment