Commit 18f030d9 authored by zhuwenwen
Browse files

修复在fp16下的数值越界导致的精度问题 (Fix the precision problem caused by numerical overflow under fp16)

parent 7bf6c98f
......@@ -164,29 +164,24 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
# if hidden_states.dtype != torch.float16:
# final_hidden_states = self.experts(
# hidden_states=hidden_states,
# router_logits=router_logits) * self.routed_scaling_factor
# else:
# # Fix FP16 overflow
# # See DeepseekV2DecoderLayer for more details.
# final_hidden_states = self.experts(hidden_states=hidden_states,
# router_logits=router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
# if shared_output is not None:
# if hidden_states.dtype != torch.float16:
# final_hidden_states = final_hidden_states + shared_output
# else:
# # Fix FP16 overflow
# # See DeepseekV2DecoderLayer for more details.
# final_hidden_states = final_hidden_states + shared_output \
# * (1. / self.routed_scaling_factor)
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
......@@ -593,29 +588,29 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states=hidden_states,
)
# if hidden_states.dtype == torch.float16:
# # Fix FP16 overflow
# # We scale both hidden_states and residual before
# # rmsnorm, and rmsnorm result would not affect by scale.
# hidden_states *= 1. / self.routed_scaling_factor
# if self.layer_idx == 0:
# # The residual is shared by all layers, we only scale it on
# # first layer.
# residual *= 1. / self.routed_scaling_factor
if hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm; the rmsnorm result is not affected by the scale.
hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0:
# The residual is shared by all layers, so we only scale it on
# the first layer.
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
# if isinstance(self.mlp,
# DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# # Fix FP16 overflow
# # Scaling the DeepseekV2MLP output, it is the input of
# # input_layernorm of next decoder layer.
# # The scaling of DeepseekV2MOE output would be done in the forward
# # of DeepseekV2MOE
# hidden_states *= 1. / self.routed_scaling_factor
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# Scale the DeepseekV2MLP output, since it is the input of the
# input_layernorm of the next decoder layer.
# The DeepseekV2MoE output is scaled inside the forward
# of DeepseekV2MoE.
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, residual
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment