"src/vscode:/vscode.git/clone" did not exist on "7200daa412b9d0738e655fbac99077f9b899d1f1"
Unverified commit 1228f7ca authored by Lianmin Zheng, committed by GitHub

Fix gptq for moe layers (#2300)


Co-authored-by: root <me@zhyncs.com>
parent fda628d8
@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
    return None
def gptq_get_quant_method(self, layer, prefix):
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
        GPTQMarlinMoEMethod,
    )

    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE

    if isinstance(layer, LinearBase):
        return GPTQMarlinLinearMethod(self)
    elif isinstance(layer, FusedMoE):
        return GPTQMarlinMoEMethod(self)
    return None
def awq_get_quant_method(self, layer, prefix):
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinLinearMethod,
        AWQMoEMethod,
    )

    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE

    if isinstance(layer, LinearBase):
        return AWQMarlinLinearMethod(self)
    elif isinstance(layer, FusedMoE):
        return AWQMoEMethod(self)
    return None
def apply_monkey_patches():
    """Apply all monkey patches in one place."""
    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
    setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)


# Apply patches when module is imported
...
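For context, a minimal, self-contained sketch of the dispatch pattern the patched get_quant_method implements. The classes below are simplified stand-ins, not the real vLLM/sglang types: each quantized layer asks its quantization config for a quant method at construction time, so overriding get_quant_method on the config class is enough to route sglang's FusedMoE layers to the Marlin MoE kernels.

# Simplified stand-ins (assumption: not the actual vLLM/sglang classes).
class LinearBase: ...
class FusedMoE: ...

class LinearMethod: ...   # stand-in for GPTQMarlinLinearMethod
class MoEMethod: ...      # stand-in for GPTQMarlinMoEMethod

class MarlinConfigSketch:
    def get_quant_method(self, layer, prefix):
        # Before this fix, only the LinearBase branch existed, so sglang's
        # FusedMoE layers received None and their experts stayed unquantized.
        if isinstance(layer, LinearBase):
            return LinearMethod()
        elif isinstance(layer, FusedMoE):
            return MoEMethod()
        return None

# setattr swaps the method on the class itself, so every config instance
# (existing or future) picks up the new dispatch:
config = MarlinConfigSketch()
assert isinstance(config.get_quant_method(FusedMoE(), prefix=""), MoEMethod)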
@@ -339,7 +339,9 @@ class MixtralForCausalLM(nn.Module):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (
                    name.endswith(".bias") or name.endswith("_bias")
                ) and name not in params_dict:
                    continue
                param = params_dict[name]
@@ -353,6 +355,10 @@ class MixtralForCausalLM(nn.Module):
                        continue
                    name = name.replace(weight_name, param_name)
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
@@ -365,7 +371,9 @@ class MixtralForCausalLM(nn.Module):
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    # Skip loading kv_scale from ckpts towards new design.
                    if name.endswith(".kv_scale") and name not in params_dict:
...
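The second half of the change widens the bias-skip predicate from ".bias" to also cover "_bias". A hedged illustration of why: after the expert-name remapping, a checkpoint bias such as "experts.0.w1.bias" can end up with a fused, underscore-separated name like "experts.w13_bias" that has no matching model parameter. The exact remapped names below are assumptions for illustration, not the actual sglang mapping table.

def should_skip(name, params_dict):
    # GPTQ checkpoints may carry bias tensors the model defines no
    # parameter for; both suffix spellings must be checked.
    return (
        name.endswith(".bias") or name.endswith("_bias")
    ) and name not in params_dict

params_dict = {"model.layers.0.block_sparse_moe.experts.w13_weight": object()}

# Dense-style bias name with no matching parameter: skipped.
assert should_skip("model.layers.0.self_attn.qkv_proj.bias", params_dict)
# Expert bias remapped to the fused prefix: also skipped after this fix.
assert should_skip("model.layers.0.block_sparse_moe.experts.w13_bias", params_dict)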