Unverified Commit 14a4d80e authored by Fan Yin, committed by GitHub

[8/n] decouple quantization impl from vllm dependency - gguf srt (#11964)


Co-authored-by: Peng Zhang <zhuangsen.zp@antgroup.com>
parent 1053e1be
@@ -12,7 +12,6 @@ try:
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )
@@ -32,9 +31,7 @@ except ImportError as e:
     AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
         ExpertsInt8Config
-    ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
-        DummyConfig
-    )
+    ) = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
@@ -45,6 +42,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
+from sglang.srt.layers.quantization.gguf import GGUFConfig
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
@@ -75,6 +73,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "w8a8_fp8": W8A8Fp8Config,
     "awq": AWQConfig,
     "awq_marlin": AWQMarlinConfig,
+    "gguf": GGUFConfig,
     "gptq": GPTQConfig,
     "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
@@ -108,7 +107,6 @@ VLLM_QUANTIZATION_METHODS = {
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "marlin": MarlinConfig,
-    "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
...
# SPDX-License-Identifier: Apache-2.0
# Adapted from: https://github.com/vllm-project/vllm/blob/ab3e80042eac24dd362408e6d63ad98768046359/vllm/model_executor/layers/quantization/gguf.py
from __future__ import annotations
import logging
import warnings
from typing import TYPE_CHECKING, Any, List, Optional
import gguf
import torch
from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.utils import is_cuda, is_hip, is_xpu, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_xpu = is_xpu()
if _is_cuda:
from sgl_kernel import gelu_and_mul, moe_align_block_size, moe_sum, silu_and_mul
from sgl_kernel.quantization import (
ggml_dequantize,
ggml_moe_a8,
ggml_moe_a8_vec,
ggml_moe_get_block_size,
ggml_mul_mat_a8,
ggml_mul_mat_vec_a8,
)
else:
warnings.warn(f"Only CUDA support GGUF q uantization currently.")
logger = logging.getLogger(__name__)
class GGUFConfig(QuantizationConfig):
"""Config class for GGUF."""
def __init__(self, modules_to_not_convert: list[str] | None = None) -> None:
super().__init__()
self.modules_to_not_convert = modules_to_not_convert or []
def __repr__(self) -> str:
return "GGUFConfig()"
def get_scaled_act_names(self) -> List[str]:
return []
def get_name(self) -> str:
return "gguf"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@classmethod
def get_min_capability(cls) -> int:
return 60
@classmethod
def get_config_filenames(cls) -> list[str]:
return [] # no extra configs.
@classmethod
def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
return cls(modules_to_not_convert)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
if isinstance(layer, LinearBase):
if is_layer_skipped_gguf(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return GGUFLinearMethod(self)
elif isinstance(layer, VocabParallelEmbedding):
return GGUFEmbeddingMethod(self)
elif isinstance(layer, FusedMoE):
return GGUFMoEMethod(self)
return None
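# Illustrative skip behavior (hypothetical module name): with
# GGUFConfig(modules_to_not_convert=["lm_head"]), any linear layer whose
# prefix contains "lm_head" falls back to UnquantizedLinearMethod instead of
# GGUFLinearMethod.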
def is_layer_skipped_gguf(prefix: str, modules_to_not_convert: list[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
STANDARD_QUANT_TYPES = {
WeightType.Q4_0,
WeightType.Q4_1,
WeightType.Q5_0,
WeightType.Q5_1,
WeightType.Q8_0,
WeightType.Q8_1,
}
KQUANT_TYPES = {
WeightType.Q2_K,
WeightType.Q3_K,
WeightType.Q4_K,
WeightType.Q5_K,
WeightType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
WeightType.IQ1_M,
WeightType.IQ1_S,
WeightType.IQ2_XXS,
WeightType.IQ2_XS,
WeightType.IQ2_S,
WeightType.IQ3_XXS,
WeightType.IQ3_S,
WeightType.IQ4_XS,
WeightType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
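# Summary of the sets above: every supported quantized type can be
# dequantized (DEQUANT_TYPES) and has a matrix-vector kernel
# (MMVQ_QUANT_TYPES); only standard and K-quant types additionally have a
# fused matrix-matrix (MMQ) kernel.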
def fused_mul_mat_gguf(
x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
) -> torch.Tensor:
if qweight_type in IMATRIX_QUANT_TYPES:
mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
else:
mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
# HACK: during chunked prefill we do not generate output tokens, so the input
# to the logits generator can be empty, which causes an invalid-parameter error
if x.shape[0] == 0:
return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
return x @ qweight.T
# enable MMVQ in contiguous batching with batch_size=1
if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
y = ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
# Use MMQ Kernel if it's available (standard + k-quants)
elif qweight_type in MMQ_QUANT_TYPES:
y = ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
# If no MMQ kernel is available, fall back to dequantization
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
y = x @ weight.T
else:
# Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = WeightType(qweight_type)
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
return y
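# Illustrative call (hypothetical CUDA tensors, Q4_K weights assumed):
#     y = fused_mul_mat_gguf(x, qweight, WeightType.Q4_K)
# A single decode token (x.shape[0] == 1) takes the MMVQ path; larger prefill
# batches use the MMQ kernel, and types without an MMQ kernel fall back to
# dequantize-then-matmul.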
def fused_moe_gguf(
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
activation: str,
) -> torch.Tensor:
def act(x: torch.Tensor):
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if activation == "silu":
silu_and_mul(out, x)
elif activation == "gelu":
gelu_and_mul(out, x)
else:
raise ValueError(f"Unsupported activation: {activation}")
return out
out_hidden_states = torch.empty_like(x)
# Unless there is decent expert reuse, we are better off running the moe_vec kernel
if (
qweight_type2 in MMQ_QUANT_TYPES
and qweight_type in MMQ_QUANT_TYPES
and x.shape[0] > 64
):
num_tokens, _ = x.shape
E, N, _ = w1.shape
top_k = topk_ids.shape[1]
BLOCK_SIZE = ggml_moe_get_block_size(qweight_type)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, BLOCK_SIZE, E
)
out = ggml_moe_a8(
x,
w1,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
qweight_type,
N,
top_k,
num_tokens,
)
out = act(out)
out = ggml_moe_a8(
out,
w2,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
qweight_type2,
w2.shape[1],
1,
num_tokens * top_k,
)
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
topk_weights.view(num_tokens, top_k, 1)
)
# TODO(FlamingoPg): maybe we can use moe_sum_reduce here?
moe_sum(out, out_hidden_states)
elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
num_tokens, _ = x.shape
E, N, _ = w1.shape
top_k = topk_ids.shape[1]
out = ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, num_tokens)
out = act(out)
out = ggml_moe_a8_vec(
out, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k
)
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
topk_weights.view(num_tokens, top_k, 1)
)
moe_sum(out, out_hidden_states)
else:
logger.warning_once(
"There is no support for fast MoE kernel "
"for current quantization method. "
"Falling back to slow implementation. "
)
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
inp = x[tok].reshape((1,) + x.shape[1:])
current_hidden_state = None
for ww, ii in zip(w, idx):
expert_up = w1[ii]
out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
out = act(out)
expert_down = w2[ii]
current_state = fused_mul_mat_gguf(
out, expert_down, qweight_type2
).mul_(ww)
if current_hidden_state is None:
current_hidden_state = current_state
else:
current_hidden_state.add_(current_state)
out_hidden_states[tok] = current_hidden_state
return out_hidden_states
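# The three branches above trade generality for speed: the MMQ path needs both
# expert weights in MMQ_QUANT_TYPES and more than 64 tokens to amortize
# moe_align_block_size; the MMVQ path covers smaller batches and i-matrix
# types; the per-token Python loop is only a slow correctness fallback.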
def apply_gguf_embedding(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: torch.dtype | None = None,
) -> torch.Tensor:
if qweight_type in UNQUANTIZED_TYPES:
return torch.embedding(qweight, x)
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
x_flat = x.flatten()
assert hidden_size == qweight.shape[1] // type_size * block_size
quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ggml_dequantize(
quant, qweight_type, hidden_size, x_flat.shape[0], dtype
)
return dequant.view(*x.shape, hidden_size)
else:
qweight_type = WeightType(qweight_type)
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
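# For quantized vocab tables, only the rows selected by the token ids are
# gathered (torch.index_select) and dequantized, so the full embedding matrix
# is never materialized in the activation dtype.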
class GGUFLinearMethod(LinearMethodBase):
"""Linear method for GGUF.
Args:
quant_config: The GGUF quantization config.
"""
def __init__(self, quant_config: GGUFConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
self.params_dtype = params_dtype
output_size_per_partition = sum(output_partition_sizes)
tensor_shape = (output_size_per_partition, input_size_per_partition)
qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
qweight,
{
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
"shard_id": [],
"shard_id_map": {},
},
)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("qweight", qweight)
qweight_type = Parameter(
torch.empty(len(output_partition_sizes), dtype=torch.uint8),
requires_grad=False,
)
set_weight_attrs(
qweight_type,
{
"is_gguf_weight_type": True,
"weight_type": 0,
"shard_weight_type": {},
"ignore_warning": True,
},
)
set_weight_attrs(qweight_type, extra_weight_attrs)
layer.register_parameter("qweight_type", qweight_type)
def process_weights_after_loading(self, layer: torch.nn.Module):
qweight_type = layer.qweight_type.weight_type
if not (qweight_type in UNQUANTIZED_TYPES or qweight_type in DEQUANT_TYPES):
qweight_type = WeightType(qweight_type)
raise ValueError(
f"Unsupported GGUF quantization type {qweight_type} in layer {layer}."
)
# For MergedColumnParallelLinear and QKVParallelLinear, we need to
# materialize the padded weight parameter for CUDA Graph compatibility.
self._create_padded_weight_param(layer)
def _create_padded_weight_param(self, layer: torch.nn.Module):
"""Create padded weight parameter for GGUF MergedLinear layer."""
qweight = layer.qweight
shard_id_map = qweight.shard_id_map
shard_id = qweight.shard_id
if len(data_container := qweight.data_container) > 1:
dtype = {data.dtype for data in data_container}
assert len(dtype) == 1, ValueError(
f"Data container has mixed dtypes: {dtype}"
)
dtype = next(iter(dtype))
# concat dim0 and pad dim1
padded_side = max(x.size(1) for x in data_container)
concat_side = sum(x.size(0) for x in data_container)
# Pad the quantized weights to dense tensor, and create a map
# with the location of each shard in the padded tensor.
padded_data = torch.zeros(
(concat_side, padded_side), dtype=dtype, device=qweight.device
)
# (dim0_start, dim0_end, dim1_size)
shard_offset_map = dict[str, tuple[int, int, int]]()
for idx in shard_id:
id_in_container = shard_id_map[idx]
start = sum(x.size(0) for x in data_container[:id_in_container])
end = start + data_container[id_in_container].size(0)
size = data_container[id_in_container].size(1)
padded_data[start:end, :size] = data_container[id_in_container]
shard_offset_map[idx] = (start, end, size)
qweight.data_container.clear()
padded_param = Parameter(padded_data, requires_grad=False)
set_weight_attrs(padded_param, vars(qweight))
set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map})
layer.register_parameter("qweight", padded_param)
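# Illustrative padding layout: if the q/k/v shards use different GGUF types,
# their quantized byte widths (dim 1) differ; the shards are concatenated
# along dim 0, zero-padded along dim 1 to the widest shard, and
# shard_offset_map stores (row_start, row_end, valid_width) so apply() can
# slice each original shard back out without the padding.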
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
shard_id = layer.qweight.shard_id
if shard_id:
# dequantize each shard's weights separately
# for fused QKV layers, normalize the shard order to q, k, v
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
qweight = layer.qweight
result = []
for idx in shard_id:
start, end, offset = layer.qweight.shard_offset_map[idx]
qweight_type = layer.qweight_type.shard_weight_type[idx]
result.append(
fused_mul_mat_gguf(
x, qweight[start:end, :offset].contiguous(), qweight_type
)
)
out = torch.cat(result, axis=1)
else:
qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type
out = fused_mul_mat_gguf(x, qweight, qweight_type)
if bias is not None:
out.add_(bias)
return out
class GGUFMoEMethod(FusedMoEMethodBase):
"""MoE method for GGUF.
Args:
quant_config: The GGUF quantization config.
"""
def __init__(self, quant_config: GGUFConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
tensor_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
# gate up proj
w13_qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
w13_qweight,
{
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
},
)
set_weight_attrs(w13_qweight, extra_weight_attrs)
layer.register_parameter("w13_qweight", w13_qweight)
w13_qweight_type = Parameter(
torch.empty(1, dtype=torch.uint8), requires_grad=False
)
set_weight_attrs(
w13_qweight_type,
{"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
)
set_weight_attrs(w13_qweight_type, extra_weight_attrs)
layer.register_parameter("w13_qweight_type", w13_qweight_type)
tensor_shape = (num_experts, intermediate_size_per_partition, hidden_size)
# down proj
w2_qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
w2_qweight,
{
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
},
)
set_weight_attrs(w2_qweight, extra_weight_attrs)
layer.register_parameter("w2_qweight", w2_qweight)
w2_qweight_type = Parameter(
torch.empty(1, dtype=torch.uint8), requires_grad=False
)
set_weight_attrs(
w2_qweight_type,
{"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
)
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
layer.register_parameter("w2_qweight_type", w2_qweight_type)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
assert self.fused_experts is None
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
assert (
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
topk_weights, topk_ids, _ = topk_output
output = fused_moe_gguf(
x=x,
w1=layer.w13_qweight,
w2=layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
qweight_type=layer.w13_qweight_type.weight_type,
qweight_type2=layer.w2_qweight_type.weight_type,
activation=moe_runner_config.activation,
)
return StandardCombineInput(hidden_states=output)
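# Note: fused_moe_gguf also implements a GELU path, but this method asserts
# SiLU via moe_runner_config.activation, so only SiLU-gated MoE models are
# accepted here.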
class GGUFEmbeddingMethod(GGUFLinearMethod):
"""Embedding method for GGUF.
Args:
quant_config: The GGUF quantization config.
"""
def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type
hidden_size = qweight.tensor_shape[1]
return apply_gguf_embedding(
x, qweight, qweight_type, hidden_size, dtype=self.params_dtype
)
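# GGUFUninitializedParameter (below) defers shape allocation: the weight
# loader stashes each shard's quantized tensor in data_container, and
# _create_padded_weight_param later merges them into a single padded Parameter.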
class GGUFUninitializedParameter(UninitializedParameter):
cls_to_become = Parameter
data_container: list[torch.Tensor]
@@ -140,7 +140,6 @@ from sglang.srt.utils import (
     is_sm100_supported,
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
-    monkey_patch_vllm_gguf_config,
     set_cuda_arch,
     slow_rank_detector,
     xpu_has_xmx_support,
@@ -858,8 +857,6 @@ class ModelRunner:
         self.model_config = adjust_config_with_unaligned_cpu_tp(
             self.model_config, self.load_config, self.tp_size
         )
-        if self.server_args.load_format == "gguf":
-            monkey_patch_vllm_gguf_config()
         if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
             if self.tp_rank == 0:
...
@@ -95,7 +95,7 @@ from sglang.srt.environ import envs
 from sglang.srt.metrics.func_timer import enable_func_timer

 if TYPE_CHECKING:
-    from sglang.srt.layers.quantization.base_config import QuantizeMethodBase
+    pass

 logger = logging.getLogger(__name__)
@@ -1069,32 +1069,6 @@ def monkey_patch_p2p_access_check():
     setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)


-def monkey_patch_vllm_gguf_config():
-    try:
-        from vllm.model_executor.layers.quantization.gguf import (
-            GGUFConfig,
-            GGUFEmbeddingMethod,
-            GGUFLinearMethod,
-        )
-    except ImportError:
-        return
-
-    from sglang.srt.layers.linear import LinearBase
-    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-
-    def get_quant_method_with_embedding_replaced(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[QuantizeMethodBase]:
-        if isinstance(layer, LinearBase):
-            return GGUFLinearMethod(self)
-        elif isinstance(layer, VocabParallelEmbedding):
-            # patch to own VocabParallelEmbedding
-            return GGUFEmbeddingMethod(self)
-        return None
-
-    setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)
-
-
 def set_ulimit(target_soft_limit=65535):
     # number of open files
     resource_type = resource.RLIMIT_NOFILE
...
@@ -197,7 +197,7 @@ suites = {
         TestFile("test_bnb.py", 5),
         TestFile("test_gptqmodel_dynamic.py", 102),
         TestFile("test_vllm_dependency.py", 185),
-        # TestFile("test_gguf.py", 96),
+        TestFile("test_gguf.py", 96),
     ],
     # If the test cases take too long, considering adding them to nightly tests instead of per-commit tests
     "nightly-1-gpu": [],
...