Commit a8d6ba1e authored by wujl5
Browse files

[fix] moe长输入场景为缓解通讯压力关掉rms_quant融合 (In MoE long-input scenarios, disable the rms_quant fusion to relieve communication pressure)

parent f4cd62b9
...@@ -231,9 +231,15 @@ class DeepseekV2MLP(nn.Module): ...@@ -231,9 +231,15 @@ class DeepseekV2MLP(nn.Module):
): ):
enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP# and not get_forward_context().draft_model enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP# and not get_forward_context().draft_model
if enable_mla_cp: if enable_mla_cp:
if iqis is not None and iqis[0] is not None and iqis[1] is not None:
i_q_gahter = tensor_model_parallel_all_gather(iqis[0].contiguous(), 0)
i_s_gather = tensor_model_parallel_all_gather(iqis[1].contiguous(), 0)
iqis = (i_q_gahter, i_s_gather)
else:
x = tensor_model_parallel_all_gather( x = tensor_model_parallel_all_gather(
x.contiguous(), 0 x.contiguous(), 0
) )
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
gate_up, _ = self.gate_up_proj(x, iqis=iqis) gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT: if envs.USE_FUSED_SILU_MUL_QUANT:
...@@ -1214,6 +1220,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1214,6 +1220,8 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1.0 / self.routed_scaling_factor residual *= 1.0 / self.routed_scaling_factor
# Fully Connected # Fully Connected
enable_mla_cp = get_forward_context().enable_mla_cp
skip_moe_large_batch_size = enable_mla_cp
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
assert self.post_attention_layernorm.has_weight is True assert self.post_attention_layernorm.has_weight is True
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states, _i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
...@@ -1222,9 +1230,10 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1222,9 +1230,10 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input=update_hs update_input=update_hs
) )
new_resi = residual new_resi = residual
hidden_states = self.mlp(hidden_states, if skip_moe_large_batch_size:
iqis=(_i_q, _i_s) hidden_states = self.mlp(hidden_states)
) else:
hidden_states = self.mlp(hidden_states, iqis=(_i_q, _i_s))
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow # Fix FP16 overflow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment