Commit a8d6ba1e authored by wujl5's avatar wujl5
Browse files

[fix] moe长输入场景为缓解通讯压力关掉rms_quant融合

parent f4cd62b9
......@@ -231,9 +231,15 @@ class DeepseekV2MLP(nn.Module):
):
enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP# and not get_forward_context().draft_model
if enable_mla_cp:
if iqis is not None and iqis[0] is not None and iqis[1] is not None:
i_q_gahter = tensor_model_parallel_all_gather(iqis[0].contiguous(), 0)
i_s_gather = tensor_model_parallel_all_gather(iqis[1].contiguous(), 0)
iqis = (i_q_gahter, i_s_gather)
else:
x = tensor_model_parallel_all_gather(
x.contiguous(), 0
)
if envs.USE_FUSED_RMS_QUANT:
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT:
......@@ -1214,6 +1220,8 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1.0 / self.routed_scaling_factor
# Fully Connected
enable_mla_cp = get_forward_context().enable_mla_cp
skip_moe_large_batch_size = enable_mla_cp
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
assert self.post_attention_layernorm.has_weight is True
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
......@@ -1222,9 +1230,10 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input=update_hs
)
new_resi = residual
hidden_states = self.mlp(hidden_states,
iqis=(_i_q, _i_s)
)
if skip_moe_large_batch_size:
hidden_states = self.mlp(hidden_states)
else:
hidden_states = self.mlp(hidden_states, iqis=(_i_q, _i_s))
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment