Unverified commit 541a985f, authored by fzyzcjy and committed by GitHub
Browse files

Fuse routed_scaling_factor in DeepSeek (#6710)

parent 5170b010
...@@ -526,9 +526,13 @@ class DeepseekV2MoE(nn.Module): ...@@ -526,9 +526,13 @@ class DeepseekV2MoE(nn.Module):
def op_output(self, state): def op_output(self, state):
final_hidden_states = state.pop("hidden_states_after_combine") final_hidden_states = state.pop("hidden_states_after_combine")
final_hidden_states *= self.routed_scaling_factor
if (s := state.pop("shared_output")) is not None: if (shared_output := state.pop("shared_output")) is not None:
final_hidden_states = final_hidden_states + s x = shared_output
x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
final_hidden_states = x
else:
final_hidden_states *= self.routed_scaling_factor
state.hidden_states_mlp_output = final_hidden_states state.hidden_states_mlp_output = final_hidden_states
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment