Commit 475dcaa0 authored by yangql's avatar yangql
Browse files

修复deepseek moe模型的awq量化推理bug和精度问题

parent efd51772
......@@ -11,7 +11,10 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8
try:
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8
except Exception:
print("INFO: Please updata lmslim if you want to use fp8_utils.\n")
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None
......
......@@ -288,6 +288,10 @@ def get_model_architecture(
os.environ['FA_PAD'] = '0'
else:
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
#针对使用dtype为fp16的情况的量化默认关闭"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"
if model_config.quantization in {"awq", "awq_marlin", "moe_wna16"}:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '0'
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
......
......@@ -385,9 +385,12 @@ class DeepseekV2MoE(nn.Module):
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
if i_q is not None:
i_q=iqis[0]
i_s=iqis[1]
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0], i_s=iqis[1])
i_q=i_q, i_s=i_s)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
......@@ -429,9 +432,12 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None
final_hidden_states += (shared_output * (1. / self.routed_scaling_factor))
else:
if i_q is not None:
i_q=iqis[0]
i_s=iqis[1]
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0], i_s=iqis[1])
i_q=i_q, i_s=i_s)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment