Unverified Commit c982ac57 authored by Concurrensee's avatar Concurrensee Committed by GitHub
Browse files

[Bugfix] Fix FP16 overflow for DeepSeek V2 (#13232)


Signed-off-by: default avatarYida Wu <yida.wu@amd.com>
parent 4290b704
......@@ -155,11 +155,21 @@ class DeepseekV2MoE(nn.Module):
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# This is a special case to avoid FP16 overflow
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# This is a special case to avoid FP16 overflow
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
......@@ -531,6 +541,7 @@ class DeepseekV2DecoderLayer(nn.Module):
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
def forward(
self,
......@@ -551,9 +562,18 @@ class DeepseekV2DecoderLayer(nn.Module):
)
# Fully Connected
if isinstance(self.mlp, DeepseekV2MoE) and \
hidden_states.dtype == torch.float16:
# This is a special case to avoid FP16 overflow
hidden_states *= 1. / self.routed_scaling_factor
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp, DeepseekV2MLP) and \
hidden_states.dtype == torch.float16:
# This is a special case to avoid FP16 overflow
hidden_states *= 1. / self.routed_scaling_factor
residual *= 1. / self.routed_scaling_factor
return hidden_states, residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment