Unverified commit adca585b, authored by yulei and committed by GitHub
Browse files

[DeepEP] Reduce routed scaling overhead (#5277)


Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
parent 39d90449
...@@ -337,8 +337,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -337,8 +337,7 @@ class DeepseekV2MoE(nn.Module):
topk_weights, topk_weights,
forward_mode=forward_mode, forward_mode=forward_mode,
) )
final_hidden_states = ( final_hidden_states = self.experts(
self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
reorder_topk_ids=reorder_topk_ids, reorder_topk_ids=reorder_topk_ids,
seg_indptr=seg_indptr, seg_indptr=seg_indptr,
...@@ -346,8 +345,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -346,8 +345,6 @@ class DeepseekV2MoE(nn.Module):
expected_m=expected_m, expected_m=expected_m,
forward_mode=forward_mode, forward_mode=forward_mode,
) )
* self.routed_scaling_factor
)
if self.ep_size > 1: if self.ep_size > 1:
final_hidden_states = self.deepep_dispatcher.combine( final_hidden_states = self.deepep_dispatcher.combine(
final_hidden_states, final_hidden_states,
...@@ -355,6 +352,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -355,6 +352,8 @@ class DeepseekV2MoE(nn.Module):
topk_weights, topk_weights,
forward_mode, forward_mode,
) )
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None: if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment