Commit bb94d2e5 authored by yangql
Browse files

增加fused-moe int4/int8的支持,以及deepseek精度问题的修复

parent 087254b9
...@@ -164,21 +164,28 @@ class DeepseekV2MoE(nn.Module): ...@@ -164,21 +164,28 @@ class DeepseekV2MoE(nn.Module):
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16: # if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts( # final_hidden_states = self.experts(
hidden_states=hidden_states, # hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor # router_logits=router_logits) * self.routed_scaling_factor
else: # else:
# This is a special case to avoid FP16 overflow # # This is a special case to avoid FP16 overflow
final_hidden_states = self.experts(hidden_states=hidden_states, # final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) # router_logits=router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
if shared_output is not None: if shared_output is not None:
if hidden_states.dtype != torch.float16: final_hidden_states = final_hidden_states + shared_output
final_hidden_states = final_hidden_states + shared_output
else: # if shared_output is not None:
# This is a special case to avoid FP16 overflow # if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output \ # final_hidden_states = final_hidden_states + shared_output
* (1. / self.routed_scaling_factor) # else:
# # This is a special case to avoid FP16 overflow
# final_hidden_states = final_hidden_states + shared_output \
# * (1. / self.routed_scaling_factor)
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states) final_hidden_states)
...@@ -571,18 +578,18 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -571,18 +578,18 @@ class DeepseekV2DecoderLayer(nn.Module):
) )
# Fully Connected # Fully Connected
if isinstance(self.mlp, DeepseekV2MoE) and \ # if isinstance(self.mlp, DeepseekV2MoE) and \
hidden_states.dtype == torch.float16: # hidden_states.dtype == torch.float16:
# This is a special case to avoid FP16 overflow # # This is a special case to avoid FP16 overflow
hidden_states *= 1. / self.routed_scaling_factor # hidden_states *= 1. / self.routed_scaling_factor
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp, DeepseekV2MLP) and \ # if isinstance(self.mlp, DeepseekV2MLP) and \
hidden_states.dtype == torch.float16: # hidden_states.dtype == torch.float16:
# This is a special case to avoid FP16 overflow # # This is a special case to avoid FP16 overflow
hidden_states *= 1. / self.routed_scaling_factor # hidden_states *= 1. / self.routed_scaling_factor
residual *= 1. / self.routed_scaling_factor # residual *= 1. / self.routed_scaling_factor
return hidden_states, residual return hidden_states, residual
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment