"vscode:/vscode.git/clone" did not exist on "9f444cd7ff4a0b3cd5e3695f891b96cf6fa9ac78"
Unverified commit 2cc9eeab, authored by Hongbo Xu and committed by GitHub

[4/n] Decouple quantization implementation from vLLM dependency (#9191)


Co-authored-by: AniZpZ <aniz1905@gmail.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 63d82a77
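The hunks below follow one pattern: kernels that were previously reached through `vllm._custom_ops` behind a try/except fallback are now imported directly from `sgl_kernel` when CUDA is available, and the GPTQ/AWQ configs move into the vLLM-free base registry. A minimal sketch of the import shape the diff converges on (illustrative only; the real code lives under `sglang/srt/layers/quantization/`):

```python
# Sketch of the import pattern this PR moves to: quantization kernels come
# straight from sgl_kernel on CUDA, with no vllm._custom_ops fallback.
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    # Direct kernel imports; previously these were reached as ops.<name>
    # via "from vllm import _custom_ops as ops" inside a try/except.
    from sgl_kernel import gptq_gemm, gptq_marlin_repack, gptq_shuffle
```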
@@ -30,14 +30,9 @@ jobs:
- name: Install dependencies
run: |
bash scripts/ci/ci_install_dependency.sh
pip install "vllm==0.10.0"
pip install "openai==1.99.1"
pip install "bitsandbytes>=0.44.0"
- # NOTE: The latest sgl-kernel depends on torch 2.8.0 but the latest vllm depends on torch 2.7.0
- # so they are not compatible. Here we install the old sgl-kernel to make the test pass.
- # TODO: remove this once vllm supports torch 2.8.0.
- pip install "sgl-kernel==0.2.9"
+ pip install "sgl-kernel==0.3.5"
- name: Run vLLM dependency tests
timeout-minutes: 60
......
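The workflow hunk above only bumps pins: the old `sgl-kernel==0.2.9` workaround for the torch 2.7/2.8 mismatch is dropped in favor of `0.3.5`. A hedged sanity check one could run before the suite, using only the standard library (the expected versions simply mirror the pins in this job):

```python
# Hedged sketch: confirm the pinned packages resolved as expected in CI.
from importlib.metadata import version

expected = {"sgl-kernel": "0.3.5", "vllm": "0.10.0", "openai": "1.99.1"}
for pkg, want in expected.items():
    got = version(pkg)
    assert got == want, f"{pkg}: expected {want}, got {got}"
```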
@@ -58,7 +58,7 @@ runtime_common = [
srt = [
"sglang[runtime_common]",
- "sgl-kernel==0.3.4.post1",
+ "sgl-kernel==0.3.5",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
......
@@ -655,7 +655,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
assert_pkg_version(
"sgl-kernel",
- "0.3.4",
+ "0.3.5",
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)
......
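The startup check now requires `sgl-kernel` 0.3.5 unless `SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK` is set. As a rough illustration, a guard like `assert_pkg_version` boils down to something like the following (a sketch only; the real helper lives in `sglang.srt.utils` and may differ in details):

```python
# Hedged sketch of a minimum-version guard in the spirit of assert_pkg_version.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def require_pkg(pkg: str, min_version: str, hint: str) -> None:
    try:
        installed = Version(version(pkg))
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {hint}")
    if installed < Version(min_version):
        raise RuntimeError(f"{pkg} {installed} < required {min_version}. {hint}")
```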
@@ -55,13 +55,7 @@ if is_mxfp_supported:
from sglang.srt.layers.quantization.fp4 import MxFp4Config
from sglang.srt.layers.quantization.fp8 import Fp8Config
- from sglang.srt.layers.quantization.gptq import (
- GPTQConfig,
- GPTQLinearMethod,
- GPTQMarlinConfig,
- GPTQMarlinLinearMethod,
- GPTQMarlinMoEMethod,
- )
+ from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptFp4Config,
ModelOptFp8Config,
@@ -70,7 +64,6 @@ from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
from sglang.srt.layers.quantization.petit import PetitNvFp4Config
from sglang.srt.layers.quantization.qoq import QoQConfig
- from sglang.srt.layers.quantization.utils import get_linear_quant_method
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -86,6 +79,10 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"modelopt_fp4": ModelOptFp4Config,
"w8a8_int8": W8A8Int8Config,
"w8a8_fp8": W8A8Fp8Config,
"awq": AWQConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig,
"gptq_marlin": GPTQMarlinConfig,
"moe_wna16": MoeWNA16Config,
"compressed-tensors": CompressedTensorsConfig,
"qoq": QoQConfig,
@@ -111,19 +108,15 @@ elif is_mxfp_supported and is_hip():
# VLLM-dependent quantization methods
VLLM_QUANTIZATION_METHODS = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fbgemm_fp8": FBGEMMFp8Config,
"marlin": MarlinConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"awq_marlin": AWQMarlinConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"experts_int8": ExpertsInt8Config,
"gptq_marlin": GPTQMarlinConfig,
"gptq": GPTQConfig,
}
QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
@@ -145,23 +138,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
return QUANTIZATION_METHODS[quantization]
- def gptq_get_quant_method(self, layer, prefix):
- from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
- if isinstance(layer, FusedMoE):
- return GPTQMarlinMoEMethod(self)
- if isinstance(self, GPTQConfig):
- return get_linear_quant_method(
- self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
- )
- elif isinstance(self, GPTQMarlinConfig):
- return get_linear_quant_method(
- self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
- )
- return None
original_isinstance = builtins.isinstance
@@ -239,10 +215,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
def monkey_patch_quant_configs():
"""Apply all monkey patches in one place."""
- setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
- setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
- monkey_patch_moe_apply(GPTQMarlinMoEMethod)
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
......
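Net effect of the `__init__.py` hunks: `awq`, `awq_marlin`, `gptq`, and `gptq_marlin` now live in `BASE_QUANTIZATION_METHODS`, the `gptq_get_quant_method` monkey patch is gone, and `VLLM_QUANTIZATION_METHODS` shrinks to the methods that still genuinely need vLLM. A hedged usage sketch of the lookup that now succeeds without vLLM installed:

```python
# Hedged usage sketch: GPTQ/AWQ config classes resolve from the base registry.
from sglang.srt.layers.quantization import get_quantization_config

gptq_cls = get_quantization_config("gptq")              # GPTQConfig
awq_marlin_cls = get_quantization_config("awq_marlin")  # AWQMarlinConfig
```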
@@ -35,22 +35,18 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
- try:
- from vllm import _custom_ops as ops
- warnings.warn(
- f"Using kernels directly from vllm. This might lead to performance degradation or "
- f"missing functionalities as certain kernels may not be optimized. "
- )
- except ImportError:
- ops = None
from sglang.srt.utils import is_cuda, is_hip
_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
- from sgl_kernel import awq_dequantize, fused_marlin_moe
+ from sgl_kernel import (
+ awq_dequantize,
+ awq_marlin_moe_repack,
+ awq_marlin_repack,
+ fused_marlin_moe,
+ )
elif _is_hip:
from sglang.srt.layers.quantization.awq_triton import (
@@ -519,7 +515,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
layer.workspace = marlin_make_workspace(device)
# Repack weights from AWQ format to marlin format.
- marlin_qweight = ops.awq_marlin_repack(
+ marlin_qweight = awq_marlin_repack(
layer.qweight,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
@@ -687,7 +683,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
requires_grad=False,
)
- marlin_w13_qweight = ops.awq_marlin_moe_repack(
+ marlin_w13_qweight = awq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
size_k=layer.w13_qweight.shape[1],
@@ -696,7 +692,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
- marlin_w2_qweight = ops.awq_marlin_moe_repack(
+ marlin_w2_qweight = awq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
size_k=layer.w2_qweight.shape[1],
......
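In `awq.py` the vLLM fallback import (and its performance warning) is removed outright; on CUDA all four entry points are expected to exist in `sgl_kernel`, and the call sites switch from `ops.awq_marlin_repack` / `ops.awq_marlin_moe_repack` to the direct functions. A hedged way to confirm an installed `sgl_kernel` build exposes what this file now imports unconditionally:

```python
# Hedged sketch: verify the sgl_kernel build exposes the AWQ/Marlin kernels
# that awq.py now imports directly on CUDA.
import importlib

required = ("awq_dequantize", "awq_marlin_moe_repack", "awq_marlin_repack", "fused_marlin_moe")
sgl_kernel = importlib.import_module("sgl_kernel")
missing = [name for name in required if not hasattr(sgl_kernel, name)]
if missing:
    raise ImportError(f"sgl_kernel is missing {missing}; reinstall sgl-kernel>=0.3.5")
```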
@@ -46,17 +46,12 @@ from sglang.srt.layers.quantization.utils import (
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
- try:
- from vllm import _custom_ops as ops
- except ImportError:
- ops = None
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
- from sgl_kernel import fused_marlin_moe
+ from sgl_kernel import fused_marlin_moe, gptq_gemm, gptq_marlin_repack, gptq_shuffle
logger = logging.getLogger(__name__)
@@ -86,9 +81,7 @@ def gptq_marlin_moe_repack(
dtype=b_q_weight.dtype,
)
for e in range(num_experts):
- output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
- b_q_weight[e], perm[e], size_k, size_n, num_bits
- )
+ output[e] = gptq_marlin_repack(b_q_weight[e], perm[e], size_k, size_n, num_bits)
return output
@@ -205,11 +198,12 @@ class GPTQConfig(QuantizationConfig):
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
- if isinstance(layer, LinearBase):
- return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
- elif isinstance(layer, FusedMoE):
+ if isinstance(layer, FusedMoE):
raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
- return None
+ else:
+ return get_linear_quant_method(
+ self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+ )
class GPTQMarlinConfig(QuantizationConfig):
@@ -531,7 +525,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.g_idx.data = torch.empty(
(0,), dtype=torch.int, device=layer.g_idx.device
)
- ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
+ gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
def apply(
self,
@@ -542,7 +536,7 @@ class GPTQLinearMethod(LinearMethodBase):
out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
reshaped_x = x.reshape(-1, x.shape[-1])
- output = ops.gptq_gemm(
+ output = gptq_gemm(
reshaped_x,
layer.qweight,
layer.qzeros,
@@ -727,7 +721,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
- x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
+ x.data = gptq_marlin_repack(
x.data.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
......
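The `gptq.py` hunks do two things: all kernels (`gptq_gemm`, `gptq_shuffle`, `gptq_marlin_repack`) come from `sgl_kernel`, and `GPTQConfig.get_quant_method` now carries the dispatch that used to sit in the removed monkey patch. Condensed from the hunk above (a sketch, not the verbatim method):

```python
# Hedged sketch of the dispatch GPTQConfig.get_quant_method performs after
# this change: reject MoE layers, route linear layers to GPTQLinearMethod.
def gptq_get_quant_method_sketch(config, layer, prefix: str):
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
    from sglang.srt.layers.quantization.gptq import GPTQLinearMethod
    from sglang.srt.layers.quantization.utils import get_linear_quant_method

    if isinstance(layer, FusedMoE):
        raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
    return get_linear_quant_method(
        config, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
    )
```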
@@ -24,7 +24,7 @@ from sglang.srt.layers.quantization.utils import (
pack_cols,
unpack_cols,
)
- from sglang.srt.utils import get_device_capability
+ from sglang.srt.utils import get_device_capability, is_cuda
if TYPE_CHECKING:
from sglang.srt.layers.linear import LinearBase
@@ -34,6 +34,11 @@ try:
except ImportError:
ops = None
+ _is_cuda = is_cuda()
+ if _is_cuda:
+ from sgl_kernel import gptq_marlin_gemm
logger = logging.getLogger(__name__)
ScalarType, scalar_types = get_scalar_types()
@@ -458,7 +463,7 @@ def apply_gptq_marlin_linear(
dtype=input.dtype,
)
- output = ops.gptq_marlin_gemm(
+ output = gptq_marlin_gemm(
reshaped_x,
None,
weight,
@@ -509,7 +514,7 @@ def apply_awq_marlin_linear(
dtype=input.dtype,
)
- output = ops.gptq_marlin_gemm(
+ output = gptq_marlin_gemm(
reshaped_x,
None,
weight,
......
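`marlin_utils.py` keeps its existing `try: ... except ImportError: ops = None` block (visible above) but no longer routes the GEMM through it: `gptq_marlin_gemm` is imported from `sgl_kernel` behind an `is_cuda()` guard and called directly by `apply_gptq_marlin_linear` and `apply_awq_marlin_linear`. A minimal sketch of that guard, mirroring the hunk:

```python
# Hedged sketch of the new guarded import in marlin_utils.py: the marlin GEMM
# comes from sgl_kernel on CUDA instead of vllm's ops module.
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import gptq_marlin_gemm
# Call sites then invoke gptq_marlin_gemm(reshaped_x, None, weight, ...) directly.
```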
@@ -149,9 +149,9 @@ suites = {
"vllm_dependency_test": [
TestFile("quant/test_awq.py", 163),
TestFile("test_bnb.py", 5),
TestFile("test_gguf.py", 96),
TestFile("test_gptqmodel_dynamic.py", 102),
TestFile("test_vllm_dependency.py", 185),
# TestFile("test_gguf.py", 96),
],
}
......
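Finally, the `vllm_dependency_test` suite drops `test_gguf.py` (commented out) while keeping the AWQ, bitsandbytes, and GPTQModel coverage. For orientation, a hedged sketch of the structure a `TestFile` entry is assumed to have, a path plus an estimated runtime in seconds used by the runner for scheduling (assumed; the real definition lives in the sglang test runner):

```python
# Hedged sketch of the assumed TestFile structure used in the suites dict.
from dataclasses import dataclass


@dataclass
class TestFile:
    name: str                   # test file path relative to the test directory
    estimated_time: float = 60  # rough runtime in seconds, used for scheduling
```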