Unverified Commit 3ae78a09 authored by Arcmoon, committed by GitHub

Add gptq quantization model support (#141)

parent ccbe1e67
@@ -19,10 +19,9 @@ class RadixAttention(nn.Module):
         head_dim,
         scaling,
         num_kv_heads,
-        layer_id,
+        layer_id
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
......
@@ -12,10 +12,13 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
 from sglang.utils import get_available_gpu_memory
 from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

 import sglang

+QUANTIONCONFIG_MAPPING = {'awq': AWQConfig,
+                          'gptq': GPTQConfig}
+
 logger = logging.getLogger("model_runner")
@@ -280,8 +283,10 @@ class ModelRunner:
             self.model_config.hf_config, "quantization_config", None
         )
         if hf_quant_config is not None:
-            # TODO: config quantization awq etc
-            quant_config = AWQConfig.from_config(hf_quant_config)
+            quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_config['quant_method'])
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
+            quant_config = quant_config_class.from_config(hf_quant_config)
             logger.info(f"quant_config: {quant_config}")
             linear_method = quant_config.get_linear_method()
         model = model_class(
......
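The two ModelRunner hunks above replace the hard-coded AWQ path with a lookup table: the `quant_method` string from the checkpoint's `quantization_config` selects the matching vLLM config class, and unknown methods fail fast. A minimal runnable sketch of that dispatch follows; `resolve_quant_config` is a hypothetical helper name for illustration only, since the commit inlines this logic inside `ModelRunner`:

```python
# Sketch of the quantization-config dispatch introduced above.
# `resolve_quant_config` is an illustrative helper, not part of the commit.
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig

QUANTIONCONFIG_MAPPING = {'awq': AWQConfig, 'gptq': GPTQConfig}


def resolve_quant_config(hf_quant_config: dict):
    """Map the checkpoint's `quant_method` string to a quant config object."""
    method = hf_quant_config['quant_method']
    quant_config_class = QUANTIONCONFIG_MAPPING.get(method)
    if quant_config_class is None:
        raise ValueError(f"Unsupported quantization method: {method}")
    # Each config class parses its own fields (bits, group_size, ...).
    return quant_config_class.from_config(hf_quant_config)
```

For a GPTQ checkpoint, `hf_quant_config` typically looks like `{'quant_method': 'gptq', 'bits': 4, 'group_size': 128, ...}`; the resulting config's `get_linear_method()` then supplies the quantized linear implementation that is passed into the model constructor as `linear_method`.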
@@ -34,6 +34,7 @@ class QWenMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str = "silu",
+        linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -41,12 +42,14 @@
             2 * [intermediate_size],
             bias=False,
             gather_output=False,
+            linear_method=linear_method
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
+            linear_method=linear_method
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -71,6 +74,7 @@ class QWenAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
+        linear_method: Optional[LinearMethodBase] = None
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -82,13 +86,18 @@ class QWenAttention(nn.Module):
         # pylint: disable=invalid-name
         self.c_attn = QKVParallelLinear(
-            hidden_size, self.head_dim, self.total_num_heads, bias=True
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            linear_method=linear_method
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
+            linear_method=linear_method
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -121,7 +130,7 @@
 class QWenBlock(nn.Module):
-    def __init__(self, config: QWenConfig, layer_id):
+    def __init__(self, config: QWenConfig, layer_id, linear_method=None):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -134,11 +143,12 @@ class QWenBlock(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             layer_id=layer_id,
+            linear_method=linear_method
         )
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
+        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, linear_method=linear_method)

     def forward(
         self,
@@ -165,7 +175,7 @@ class QWenBlock(nn.Module):
 class QWenModel(nn.Module):
-    def __init__(self, config: QWenConfig):
+    def __init__(self, config: QWenConfig, linear_method=None):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -176,7 +186,7 @@ class QWenModel(nn.Module):
             config.hidden_size,
         )
         self.h = nn.ModuleList(
-            [QWenBlock(config, i) for i in range(config.num_hidden_layers)]
+            [QWenBlock(config, i, linear_method=linear_method) for i in range(config.num_hidden_layers)]
         )
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -202,7 +212,7 @@ class QWenLMHeadModel(nn.Module):
     def __init__(self, config: QWenConfig, linear_method=None):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(config)
+        self.transformer = QWenModel(config, linear_method=linear_method)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -219,9 +229,6 @@ class QWenLMHeadModel(nn.Module):
         )
         return next_tokens

-    _column_parallel_weights = []
-    _row_parallel_weights = ["c_proj.weight"]
-
     def load_weights(
         self,
         model_name_or_path: str,
......
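Every qwen.py hunk above applies the same dependency-injection pattern: one `linear_method` object, built from the quant config in `ModelRunner`, is threaded down through `QWenLMHeadModel` → `QWenModel` → `QWenBlock` → `QWenAttention`/`QWenMLP`, so every parallel linear layer is constructed with the same (possibly quantized) implementation. A runnable toy version of that pattern, in plain PyTorch with placeholder class names rather than the real sglang/vLLM modules:

```python
# Toy illustration of threading `linear_method` through constructors.
# LinearMethodBase/ToyBlock/ToyModel here are placeholders, not sglang code.
from typing import Optional

import torch
import torch.nn as nn


class LinearMethodBase:
    """Stand-in for vLLM's linear-method interface: it builds linear layers."""

    def create_linear(self, in_features: int, out_features: int) -> nn.Module:
        # A quantized method would return a quantized linear module instead.
        return nn.Linear(in_features, out_features, bias=False)


class ToyBlock(nn.Module):
    def __init__(self, hidden: int, linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        # The real code passes linear_method into QKVParallelLinear,
        # MergedColumnParallelLinear, and RowParallelLinear the same way.
        self.proj = (linear_method or LinearMethodBase()).create_linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class ToyModel(nn.Module):
    def __init__(self, hidden: int, layers: int,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        # Mirrors `[QWenBlock(config, i, linear_method=linear_method) ...]`.
        self.blocks = nn.ModuleList(
            [ToyBlock(hidden, linear_method=linear_method) for _ in range(layers)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x


model = ToyModel(hidden=8, layers=2, linear_method=LinearMethodBase())
print(model(torch.randn(1, 8)).shape)  # torch.Size([1, 8])
```

Passing the method object down once, instead of re-resolving the quant config in every layer, keeps the model code agnostic to whether weights are fp16, AWQ, or GPTQ.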
@@ -259,4 +259,4 @@ def load_image(image_file):
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
-    return image
+    return image
\ No newline at end of file
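The last hunk only changes the file's trailing newline, but the base64 branch of `load_image` shown above is easy to exercise. A quick roundtrip, assuming Pillow is installed:

```python
import base64
from io import BytesIO

from PIL import Image

# Encode a tiny image to base64, then decode it the way load_image does.
buf = BytesIO()
Image.new("RGB", (2, 2), "red").save(buf, format="PNG")
image_file = base64.b64encode(buf.getvalue()).decode()

image = Image.open(BytesIO(base64.b64decode(image_file)))
print(image.size)  # (2, 2)
```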