Unverified Commit e5192879 authored by CSWYF3634076's avatar CSWYF3634076 Committed by GitHub
Browse files

[Model][Bugfix] fix ernie45 vl run failed from shared experts optimization (#26885)


Signed-off-by: default avatarwangyafeng <wangyafeng@baidu.com>
parent d2740faf
...@@ -341,7 +341,10 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -341,7 +341,10 @@ class Ernie4_5_VLMoeMoE(nn.Module):
# text and vision modals input # text and vision modals input
visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool() visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool()
text_token_mask = ~visual_token_mask text_token_mask = ~visual_token_mask
final_hidden_states = torch.zeros_like(hidden_states) final_experts_hidden_states = torch.zeros_like(hidden_states)
final_shared_ouput = (
torch.zeros_like(hidden_states) if self.has_shared_experts else None
)
text_hidden_states = hidden_states[text_token_mask].reshape( text_hidden_states = hidden_states[text_token_mask].reshape(
-1, self.hidden_size -1, self.hidden_size
...@@ -353,16 +356,26 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -353,16 +356,26 @@ class Ernie4_5_VLMoeMoE(nn.Module):
text_router_logits, _ = self.text_experts_gate( text_router_logits, _ = self.text_experts_gate(
text_hidden_states.to(dtype=torch.float32) text_hidden_states.to(dtype=torch.float32)
) )
final_hidden_states[text_token_mask] = self.text_experts( text_shared_ouput, text_experts_output = self.text_experts(
hidden_states=text_hidden_states, router_logits=text_router_logits hidden_states=text_hidden_states, router_logits=text_router_logits
).flatten() )
final_experts_hidden_states[text_token_mask] = text_experts_output.flatten()
if self.has_shared_experts:
final_shared_ouput[text_token_mask] = text_shared_ouput.flatten()
vision_router_logits, _ = self.vision_experts_gate( vision_router_logits, _ = self.vision_experts_gate(
vision_hidden_states.to(dtype=torch.float32) vision_hidden_states.to(dtype=torch.float32)
) )
final_hidden_states[visual_token_mask] = self.vision_experts( vision_shared_ouput, vision_experts_output = self.vision_experts(
hidden_states=vision_hidden_states, router_logits=vision_router_logits hidden_states=vision_hidden_states, router_logits=vision_router_logits
).flatten() )
final_experts_hidden_states[visual_token_mask] = (
vision_experts_output.flatten()
)
if self.has_shared_experts:
final_shared_ouput[visual_token_mask] = vision_shared_ouput.flatten()
final_hidden_states = (final_shared_ouput, final_experts_hidden_states)
else: else:
# only text modal input # only text modal input
text_router_logits, _ = self.text_experts_gate( text_router_logits, _ = self.text_experts_gate(
...@@ -374,7 +387,11 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -374,7 +387,11 @@ class Ernie4_5_VLMoeMoE(nn.Module):
) )
if self.has_shared_experts: if self.has_shared_experts:
# for shared_experts model
final_hidden_states = final_hidden_states[0] + final_hidden_states[1] final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
else:
# for not shared_experts model
final_hidden_states = final_hidden_states[1]
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = ( final_hidden_states = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment