Unverified Commit baaa2b24 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Bugfix] fix moe_wna16 get_quant_method (#12648)

Fix https://github.com/vllm-project/vllm/issues/12647
The `get_quant_method` of `moe_wna16` always return moe method,
GPTQ-based linear method or AWQ-based linear method, even when the
target module is attention layer.


https://github.com/vllm-project/vllm/blob/baeded25699f9f4851843306f27f685c4d4ee7c5/vllm/attention/layer.py#L86-L92

Signed-off-by: default avatarJinzhen Lin <linjinzhen@hotmail.com>
parent b4e5c033
...@@ -6,16 +6,13 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group ...@@ -6,16 +6,13 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.awq import (AWQConfig, from vllm.model_executor.layers.quantization.awq import AWQConfig
AWQLinearMethod) from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig, AWQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig, from vllm.model_executor.layers.quantization.gptq import GPTQConfig
GPTQLinearMethod)
from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig, GPTQMarlinLinearMethod) GPTQMarlinConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -131,18 +128,18 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -131,18 +128,18 @@ class MoeWNA16Config(QuantizationConfig):
else: else:
if self.linear_quant_method == "gptq": if self.linear_quant_method == "gptq":
if self.use_marlin: if self.use_marlin:
return GPTQMarlinLinearMethod( return GPTQMarlinConfig.from_config(
GPTQMarlinConfig.from_config(self.full_config)) self.full_config).get_quant_method(layer, prefix)
else: else:
return GPTQLinearMethod( return GPTQConfig.from_config(
GPTQConfig.from_config(self.full_config)) self.full_config).get_quant_method(layer, prefix)
elif self.linear_quant_method == "awq": elif self.linear_quant_method == "awq":
if self.use_marlin: if self.use_marlin:
return AWQMarlinLinearMethod( return AWQMarlinConfig.from_config(
AWQMarlinConfig.from_config(self.full_config)) self.full_config).get_quant_method(layer, prefix)
else: else:
return AWQLinearMethod( return AWQConfig.from_config(
AWQConfig.from_config(self.full_config)) self.full_config).get_quant_method(layer, prefix)
else: else:
raise ValueError("moe_wna16 only support gptq and awq.") raise ValueError("moe_wna16 only support gptq and awq.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment