Unverified Commit 2cc9eeab authored by Hongbo Xu, committed by GitHub

[4/n] decouple quantization implementation from vLLM dependency (#9191)


Co-authored-by: AniZpZ <aniz1905@gmail.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 63d82a77
@@ -30,14 +30,9 @@ jobs:
       - name: Install dependencies
         run: |
           bash scripts/ci/ci_install_dependency.sh
-          pip install "vllm==0.10.0"
-          pip install "openai==1.99.1"
           pip install "bitsandbytes>=0.44.0"
-          # NOTE: The latest sgl-kernel depends on torch 2.8.0 but the latest vllm depends on torch 2.7.0
-          # so they are not compatible. Here we install the old sgl-kernel to make the test pass.
-          # TODO: remove this once vllm supports torch 2.8.0.
-          pip install "sgl-kernel==0.2.9"
+          pip install "sgl-kernel==0.3.5"
       - name: Run vLLM dependency tests
         timeout-minutes: 60
...
@@ -58,7 +58,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.4.post1",
+    "sgl-kernel==0.3.5",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
...
@@ -655,7 +655,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.4",
+            "0.3.5",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
...
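The bumped pin is enforced at server launch only on CUDA builds, and `SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK` bypasses it. Below is a minimal, standalone sketch of that gate using only the standard library; `check_sgl_kernel` is a made-up name, not SGLang's `assert_pkg_version`, and the boolean parsing of the environment variable is simplified.

import os
from importlib.metadata import PackageNotFoundError, version


def check_sgl_kernel(min_version: str = "0.3.5") -> None:
    # Same escape hatch the launcher honors.
    if os.environ.get("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK", "").lower() in ("1", "true"):
        return
    try:
        installed = version("sgl-kernel")
    except PackageNotFoundError:
        raise RuntimeError("sgl-kernel is not installed") from None

    def numeric(v: str) -> tuple:
        # Naive X.Y.Z comparison; ignores suffixes such as ".post1".
        return tuple(int(p) for p in v.split(".")[:3] if p.isdigit())

    if numeric(installed) < numeric(min_version):
        raise RuntimeError(
            f"sgl-kernel {installed} is older than {min_version}. Please reinstall "
            "the latest version with `pip install sgl-kernel --force-reinstall`"
        )


if __name__ == "__main__":
    check_sgl_kernel()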
@@ -55,13 +55,7 @@ if is_mxfp_supported:
 from sglang.srt.layers.quantization.fp4 import MxFp4Config
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import (
-    GPTQConfig,
-    GPTQLinearMethod,
-    GPTQMarlinConfig,
-    GPTQMarlinLinearMethod,
-    GPTQMarlinMoEMethod,
-)
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
@@ -70,7 +64,6 @@ from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
-from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -86,6 +79,10 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "awq": AWQConfig,
+    "awq_marlin": AWQMarlinConfig,
+    "gptq": GPTQConfig,
+    "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
@@ -111,19 +108,15 @@ elif is_mxfp_supported and is_hip():
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
-    "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
-    "awq_marlin": AWQMarlinConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "gptq": GPTQConfig,
 }

 QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
@@ -145,23 +138,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]

-def gptq_get_quant_method(self, layer, prefix):
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-    if isinstance(layer, FusedMoE):
-        return GPTQMarlinMoEMethod(self)
-
-    if isinstance(self, GPTQConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-        )
-    elif isinstance(self, GPTQMarlinConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-        )
-    return None
-
 original_isinstance = builtins.isinstance
@@ -239,10 +215,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
-    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
...
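The effect of the registry change above: `awq`, `awq_marlin`, `gptq`, and `gptq_marlin` now resolve from `BASE_QUANTIZATION_METHODS`, so looking them up no longer requires vLLM; only the keys that stay in `VLLM_QUANTIZATION_METHODS` do. A stripped-down sketch of the two-tier lookup, with config classes replaced by placeholder strings and a trimmed key set for illustration:

from typing import Dict

# Placeholder values; in SGLang these are QuantizationConfig subclasses.
BASE_QUANTIZATION_METHODS: Dict[str, str] = {
    "fp8": "Fp8Config",
    "awq": "AWQConfig",
    "awq_marlin": "AWQMarlinConfig",
    "gptq": "GPTQConfig",
    "gptq_marlin": "GPTQMarlinConfig",
}

# These keys still map to vLLM-backed configs and need vLLM at runtime.
VLLM_QUANTIZATION_METHODS: Dict[str, str] = {
    "aqlm": "AQLMConfig",
    "marlin": "MarlinConfig",
    "gguf": "GGUFConfig",
}

QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}


def get_quantization_config(quantization: str) -> str:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(
            f"Invalid quantization method: {quantization}. "
            f"Available methods: {sorted(QUANTIZATION_METHODS)}"
        )
    return QUANTIZATION_METHODS[quantization]


print(get_quantization_config("gptq_marlin"))  # resolves without importing vllm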
@@ -35,22 +35,18 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput

-try:
-    from vllm import _custom_ops as ops
-
-    warnings.warn(
-        f"Using kernels directly from vllm. This might lead to performance degradation or "
-        f"missing functionalities as certain kernels may not be optimized. "
-    )
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()

 if _is_cuda:
-    from sgl_kernel import awq_dequantize, fused_marlin_moe
+    from sgl_kernel import (
+        awq_dequantize,
+        awq_marlin_moe_repack,
+        awq_marlin_repack,
+        fused_marlin_moe,
+    )
 elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
@@ -519,7 +515,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
         layer.workspace = marlin_make_workspace(device)

         # Repack weights from AWQ format to marlin format.
-        marlin_qweight = ops.awq_marlin_repack(
+        marlin_qweight = awq_marlin_repack(
             layer.qweight,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
@@ -687,7 +683,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
             requires_grad=False,
         )

-        marlin_w13_qweight = ops.awq_marlin_moe_repack(
+        marlin_w13_qweight = awq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            size_k=layer.w13_qweight.shape[1],
@@ -696,7 +692,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

-        marlin_w2_qweight = ops.awq_marlin_moe_repack(
+        marlin_w2_qweight = awq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            size_k=layer.w2_qweight.shape[1],
...
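The AWQ-Marlin path now calls repack kernels imported directly from `sgl_kernel` instead of going through `vllm._custom_ops`. A small sketch of that guarded-import pattern is below; `repack_awq_weight_to_marlin` is a hypothetical wrapper around the call site shown above, and the trailing `num_bits` argument is assumed from the AWQ config rather than visible in this hunk.

import torch

from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    # Exposed by sgl-kernel >= 0.3.5, matching the pins earlier in this commit.
    from sgl_kernel import awq_marlin_repack


def repack_awq_weight_to_marlin(
    qweight: torch.Tensor, size_k: int, size_n: int, num_bits: int = 4
) -> torch.Tensor:
    # Hypothetical wrapper mirroring the AWQMarlinLinearMethod call site above.
    if not _is_cuda:
        raise RuntimeError("awq_marlin_repack needs a CUDA build of sgl-kernel")
    return awq_marlin_repack(qweight, size_k=size_k, size_n=size_n, num_bits=num_bits)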
@@ -46,17 +46,12 @@ from sglang.srt.layers.quantization.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput

-try:
-    from vllm import _custom_ops as ops
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()

 if _is_cuda:
-    from sgl_kernel import fused_marlin_moe
+    from sgl_kernel import fused_marlin_moe, gptq_gemm, gptq_marlin_repack, gptq_shuffle

 logger = logging.getLogger(__name__)
@@ -86,9 +81,7 @@ def gptq_marlin_moe_repack(
         dtype=b_q_weight.dtype,
     )
     for e in range(num_experts):
-        output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
-            b_q_weight[e], perm[e], size_k, size_n, num_bits
-        )
+        output[e] = gptq_marlin_repack(b_q_weight[e], perm[e], size_k, size_n, num_bits)
     return output
@@ -205,11 +198,12 @@ class GPTQConfig(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

-        if isinstance(layer, LinearBase):
-            return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
-        elif isinstance(layer, FusedMoE):
+        if isinstance(layer, FusedMoE):
             raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
-        return None
+        else:
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+            )

 class GPTQMarlinConfig(QuantizationConfig):
@@ -531,7 +525,7 @@ class GPTQLinearMethod(LinearMethodBase):
             layer.g_idx.data = torch.empty(
                 (0,), dtype=torch.int, device=layer.g_idx.device
             )
-        ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
+        gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)

     def apply(
         self,
@@ -542,7 +536,7 @@ class GPTQLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
         reshaped_x = x.reshape(-1, x.shape[-1])

-        output = ops.gptq_gemm(
+        output = gptq_gemm(
             reshaped_x,
             layer.qweight,
             layer.qzeros,
@@ -727,7 +721,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
-            x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
+            x.data = gptq_marlin_repack(
                 x.data.contiguous(),
                 perm=layer.g_idx_sort_indices,
                 size_k=c.partition_weight_shape[0],
...
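With the monkey patch gone, `GPTQConfig.get_quant_method` itself decides per layer: MoE layers are rejected in favor of `gptq_marlin`, everything else goes through the linear method. A simplified, self-contained sketch of that dispatch follows; the stand-in classes and the direct construction of `GPTQLinearMethod` replace SGLang's real `get_linear_quant_method`, which additionally filters by layer type.

class LinearBase:
    """Stand-in for sglang.srt.layers.linear.LinearBase."""


class FusedMoE:
    """Stand-in for sglang.srt.layers.moe.fused_moe_triton.FusedMoE."""


class GPTQLinearMethod:
    def __init__(self, quant_config):
        self.quant_config = quant_config


class GPTQConfigSketch:
    """Simplified stand-in for the GPTQConfig shown in the diff."""

    def get_quant_method(self, layer, prefix: str):
        if isinstance(layer, FusedMoE):
            # Plain GPTQ has no MoE kernel; the marlin variant handles MoE.
            raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
        # The real code routes through get_linear_quant_method; constructing
        # the linear method directly keeps this sketch short.
        return GPTQLinearMethod(self)


config = GPTQConfigSketch()
method = config.get_quant_method(LinearBase(), prefix="model.layers.0.mlp.down_proj")
print(type(method).__name__)  # -> GPTQLinearMethod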
@@ -24,7 +24,7 @@ from sglang.srt.layers.quantization.utils import (
     pack_cols,
     unpack_cols,
 )
-from sglang.srt.utils import get_device_capability
+from sglang.srt.utils import get_device_capability, is_cuda

 if TYPE_CHECKING:
     from sglang.srt.layers.linear import LinearBase
@@ -34,6 +34,11 @@ try:
 except ImportError:
     ops = None

+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm
+
 logger = logging.getLogger(__name__)

 ScalarType, scalar_types = get_scalar_types()
@@ -458,7 +463,7 @@ def apply_gptq_marlin_linear(
         dtype=input.dtype,
     )

-    output = ops.gptq_marlin_gemm(
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
@@ -509,7 +514,7 @@ def apply_awq_marlin_linear(
         dtype=input.dtype,
     )

-    output = ops.gptq_marlin_gemm(
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
...
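`gptq_marlin_gemm` is only imported when `is_cuda()` is true, and Marlin kernels additionally need an Ampere-or-newer GPU. A hedged sketch of a capability probe built from the helpers imported above; the compute-capability 8.0 threshold is the usual Marlin requirement and is assumed here, not stated in the diff, as is the `(major, minor)` return shape of `get_device_capability`.

from sglang.srt.utils import get_device_capability, is_cuda


def marlin_gemm_available(min_capability: int = 80) -> bool:
    # Assumed probe: Marlin kernels generally need compute capability >= 8.0.
    if not is_cuda():
        return False
    major, minor = get_device_capability()
    if major is None or minor is None:
        return False
    return major * 10 + minor >= min_capability


if marlin_gemm_available():
    from sgl_kernel import gptq_marlin_gemm  # noqa: F401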
@@ -149,9 +149,9 @@ suites = {
     "vllm_dependency_test": [
         TestFile("quant/test_awq.py", 163),
         TestFile("test_bnb.py", 5),
-        TestFile("test_gguf.py", 96),
         TestFile("test_gptqmodel_dynamic.py", 102),
         TestFile("test_vllm_dependency.py", 185),
+        # TestFile("test_gguf.py", 96),
     ],
 }
...
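Each suite entry pairs a test file with a rough runtime budget in seconds. A small sketch of how such a registry can report the expected wall time of the trimmed `vllm_dependency_test` suite; the `TestFile` field names are inferred from the calls above, and the real runner in the repository does considerably more than this.

from typing import Dict, List, NamedTuple


class TestFile(NamedTuple):
    name: str
    estimated_seconds: int = 60


suites: Dict[str, List[TestFile]] = {
    "vllm_dependency_test": [
        TestFile("quant/test_awq.py", 163),
        TestFile("test_bnb.py", 5),
        TestFile("test_gptqmodel_dynamic.py", 102),
        TestFile("test_vllm_dependency.py", 185),
        # TestFile("test_gguf.py", 96),  # disabled in this commit
    ],
}


def estimated_runtime(suite: str) -> int:
    return sum(t.estimated_seconds for t in suites[suite])


print(estimated_runtime("vllm_dependency_test"), "seconds")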