Unverified Commit ef6e0e71 authored by CSWYF3634076's avatar CSWYF3634076 Committed by GitHub
Browse files

[Bugfix][Model]fix ernie45 moe gate&bias dtype to float32 (#25936)


Signed-off-by: default avatarwangyafeng <wangyafeng@baidu.com>
parent 1ad3aca6
...@@ -120,11 +120,12 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -120,11 +120,12 @@ class Ernie4_5_MoeMoE(nn.Module):
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
config.moe_num_experts, config.moe_num_experts,
bias=False, bias=False,
params_dtype=torch.float32,
quant_config=None, quant_config=None,
prefix=f"{prefix}.gate") prefix=f"{prefix}.gate")
self.gate.e_score_correction_bias = nn.Parameter( self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.moe_num_experts)) torch.empty(config.moe_num_experts, dtype=torch.float32))
self.experts = FusedMoE( self.experts = FusedMoE(
num_experts=config.moe_num_experts, num_experts=config.moe_num_experts,
...@@ -157,7 +158,7 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -157,7 +158,7 @@ class Ernie4_5_MoeMoE(nn.Module):
if self.has_shared_experts: if self.has_shared_experts:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
......
...@@ -199,7 +199,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -199,7 +199,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
assert config.moe_num_experts[0] == config.moe_num_experts[1] assert config.moe_num_experts[0] == config.moe_num_experts[1]
self.e_score_correction_bias = nn.Parameter( self.e_score_correction_bias = nn.Parameter(
torch.empty(2, config.moe_num_experts[0])) torch.empty(2, config.moe_num_experts[0], dtype=torch.float32))
assert text_moe_layer_start_index <= text_moe_layer_end_index assert text_moe_layer_start_index <= text_moe_layer_end_index
...@@ -209,6 +209,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -209,6 +209,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
config.hidden_size, config.hidden_size,
config.moe_num_experts[0], config.moe_num_experts[0],
bias=False, bias=False,
params_dtype=torch.float32,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.text_experts_gate") prefix=f"{prefix}.text_experts_gate")
...@@ -238,6 +239,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -238,6 +239,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
config.hidden_size, config.hidden_size,
config.moe_num_experts[1], config.moe_num_experts[1],
bias=False, bias=False,
params_dtype=torch.float32,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.vision_experts_gate") prefix=f"{prefix}.vision_experts_gate")
...@@ -288,7 +290,8 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -288,7 +290,8 @@ class Ernie4_5_VLMoeMoE(nn.Module):
if visual_token_mask is not None and visual_token_mask.all(): if visual_token_mask is not None and visual_token_mask.all():
# only vision modal input # only vision modal input
router_logits, _ = self.vision_experts_gate(hidden_states) router_logits, _ = self.vision_experts_gate(
hidden_states.to(dtype=torch.float32))
final_hidden_states = self.vision_experts( final_hidden_states = self.vision_experts(
hidden_states=hidden_states, router_logits=router_logits) hidden_states=hidden_states, router_logits=router_logits)
elif visual_token_mask is not None and visual_token_mask.any(): elif visual_token_mask is not None and visual_token_mask.any():
...@@ -303,19 +306,21 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -303,19 +306,21 @@ class Ernie4_5_VLMoeMoE(nn.Module):
vision_hidden_states = hidden_states[visual_token_mask].reshape( vision_hidden_states = hidden_states[visual_token_mask].reshape(
-1, self.hidden_size) -1, self.hidden_size)
text_router_logits, _ = self.text_experts_gate(text_hidden_states) text_router_logits, _ = self.text_experts_gate(
text_hidden_states.to(dtype=torch.float32))
final_hidden_states[text_token_mask] = self.text_experts( final_hidden_states[text_token_mask] = self.text_experts(
hidden_states=text_hidden_states, hidden_states=text_hidden_states,
router_logits=text_router_logits).flatten() router_logits=text_router_logits).flatten()
vision_router_logits, _ = self.vision_experts_gate( vision_router_logits, _ = self.vision_experts_gate(
vision_hidden_states) vision_hidden_states.to(dtype=torch.float32))
final_hidden_states[visual_token_mask] = self.vision_experts( final_hidden_states[visual_token_mask] = self.vision_experts(
hidden_states=vision_hidden_states, hidden_states=vision_hidden_states,
router_logits=vision_router_logits).flatten() router_logits=vision_router_logits).flatten()
else: else:
# only text modal input # only text modal input
text_router_logits, _ = self.text_experts_gate(hidden_states) text_router_logits, _ = self.text_experts_gate(
hidden_states.to(dtype=torch.float32))
final_hidden_states = self.text_experts( final_hidden_states = self.text_experts(
hidden_states=hidden_states, router_logits=text_router_logits) hidden_states=hidden_states, router_logits=text_router_logits)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment