Unverified commit 72bfb0ba, authored by fzyzcjy, committed by GitHub

Refactor DeepSeek MoE layer to unify the two forward branches (#6325)

parent 15521495
@@ -194,6 +194,14 @@ class MoEGate(nn.Module):
         return logits
 
 
+def is_non_idle_and_non_empty(forward_mode, hidden_states):
+    return (
+        (forward_mode is not None)
+        and not forward_mode.is_idle()
+        and hidden_states.shape[0] > 0
+    )
+
+
 class DeepseekV2MoE(nn.Module):
     def __init__(
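The new module-level helper centralizes the "is there real routed work on this rank" test that the DeepEP path previously inlined. A minimal sketch of its behavior, using a stand-in mode object (`_StubMode` is hypothetical; only `is_idle()` matters here):

import torch

def is_non_idle_and_non_empty(forward_mode, hidden_states):
    return (
        (forward_mode is not None)
        and not forward_mode.is_idle()
        and hidden_states.shape[0] > 0
    )

class _StubMode:
    def __init__(self, idle):
        self._idle = idle

    def is_idle(self):
        return self._idle

tokens = torch.randn(4, 16)  # a batch of 4 token embeddings
empty = torch.empty(0, 16)   # an idle rank's empty batch

print(is_non_idle_and_non_empty(_StubMode(False), tokens))  # True: real work
print(is_non_idle_and_non_empty(_StubMode(True), tokens))   # False: idle mode
print(is_non_idle_and_non_empty(_StubMode(False), empty))   # False: zero tokens
print(is_non_idle_and_non_empty(None, tokens))              # False: no mode given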
@@ -259,11 +267,12 @@ class DeepseekV2MoE(nn.Module):
             ),
         )
 
+        self.top_k = config.num_experts_per_tok
+
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
             self.ep_size = get_tensor_model_parallel_world_size()
             self.num_experts = config.n_routed_experts
-            self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
             self.num_expert_group = config.n_group
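Hoisting `self.top_k` out of the `enable_deepep_moe` branch is required by the unification: the zero-row placeholder `torch.empty((0, self.top_k), ...)` in the unified `forward` below is no longer confined to the DeepEP-only configuration branch, and an attribute assigned only inside a skipped config branch does not exist at all. A distilled illustration with hypothetical `Before`/`After` classes:

class Before:
    def __init__(self, enable_deepep: bool):
        if enable_deepep:
            self.top_k = 8  # exists only when the flag is on

class After:
    def __init__(self, enable_deepep: bool):
        self.top_k = 8  # always present, branch-independent

print(Before(True).top_k)  # 8
print(After(False).top_k)  # 8
try:
    Before(False).top_k
except AttributeError as e:
    print(e)  # 'Before' object has no attribute 'top_k'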
@@ -286,41 +295,30 @@ class DeepseekV2MoE(nn.Module):
                 return_recv_hook=True,
             )
 
+    @property
+    def _enable_deepep_moe(self):
+        return global_server_args_dict["enable_deepep_moe"]
+
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
-        if not global_server_args_dict["enable_deepep_moe"]:
-            return self.forward_normal(hidden_states)
-        else:
-            return self.forward_deepep(hidden_states, forward_batch)
-
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        shared_output = self._forward_shared_experts(hidden_states)
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
-        final_hidden_states *= self.routed_scaling_factor
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-        return final_hidden_states
-
-    def forward_deepep(
-        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
-    ) -> torch.Tensor:
         forward_mode = forward_batch.forward_mode
-        shared_output = None
-        if (
-            forward_mode is not None
-            and not forward_mode.is_idle()
-            and hidden_states.shape[0] > 0
+
+        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
+            forward_mode, hidden_states
         ):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            shared_output = self._forward_shared_experts(hidden_states)
+        else:
+            router_logits = None
+
+        if (self.n_share_experts_fusion == 0) and (
+            (not self._enable_deepep_moe)
+            or is_non_idle_and_non_empty(forward_mode, hidden_states)
+        ):
+            shared_output = self.shared_experts(hidden_states)
+        else:
+            shared_output = None
+
+        if self._enable_deepep_moe and (router_logits is not None):
            topk_weights, topk_idx = select_experts(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
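Both new guards in the unified `forward` share one predicate: run the gate (and the shared experts) either when DeepEP is off, or when the DeepEP batch carries real work; the shared-experts guard additionally requires that shared experts are not fused into the routed ones. A condensed restatement with stand-in locals (not the class's API), reusing `is_non_idle_and_non_empty` from this diff:

def wants_gate(enable_deepep, forward_mode, hidden_states):
    # Normal mode always routes; DeepEP routes only on non-idle, non-empty batches.
    return (not enable_deepep) or is_non_idle_and_non_empty(forward_mode, hidden_states)

def wants_shared_experts(n_share_experts_fusion, enable_deepep, forward_mode, hidden_states):
    # Same condition, plus: skip when shared experts are fused into the routed experts.
    return (n_share_experts_fusion == 0) and wants_gate(
        enable_deepep, forward_mode, hidden_states
    )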
@@ -340,7 +338,8 @@ class DeepseekV2MoE(nn.Module):
             topk_weights = torch.empty(
                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
             )
-        if self.ep_size > 1:
+
+        if self._enable_deepep_moe and (self.ep_size > 1):
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             (
                 hidden_states,
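On the no-work path the code still builds a zero-row `topk_weights` (and, per the elided lines, presumably a matching `topk_idx`) before the dispatch guard: DeepEP's dispatch collective must be entered with consistently shaped and typed tensors even when this rank contributes zero tokens. A sketch of such placeholders; the hidden size and the `topk_idx` construction are illustrative assumptions, only `topk_weights` appears verbatim in the diff:

import torch

top_k = 8
hidden_states = torch.empty(0, 7168)  # idle rank: zero tokens this step

# Zero-row placeholders keep the shape/dtype contract intact for the collective.
topk_weights = torch.empty(
    (0, top_k), dtype=torch.float32, device=hidden_states.device
)
topk_idx = torch.full(  # assumed analogue for the index tensor
    (0, top_k), -1, dtype=torch.int64, device=hidden_states.device
)
print(topk_weights.shape, topk_idx.shape)  # torch.Size([0, 8]) torch.Size([0, 8])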
@@ -357,6 +356,8 @@ class DeepseekV2MoE(nn.Module):
                 topk_weights,
                 forward_mode=forward_mode,
             )
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            topk_idx=topk_idx,
+
+        if self._enable_deepep_moe:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states,
+                topk_idx=topk_idx,
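The two `self.experts(...)` invocations that this guard and the `else:` below select between take different keyword sets for a reason: on the DeepEP path, routing has already happened (`select_experts`) and tokens have been dispatched, so the experts consume precomputed `topk_idx`/`topk_weights`; on the fused path, the experts module routes internally from the raw `router_logits`. Schematic stub signatures (not the real classes):

def deepep_experts(hidden_states, topk_idx, topk_weights, forward_mode, **extra):
    """Expert computation over tokens already routed and dispatched."""
    ...

def fused_moe_experts(hidden_states, router_logits):
    """Does its own top-k selection from the logits, then runs the experts."""
    ...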
@@ -368,25 +369,28 @@ class DeepseekV2MoE(nn.Module):
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_mode=forward_mode,
-        )
-        if self.ep_size > 1:
+                num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+                forward_mode=forward_mode,
+            )
+        else:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+
+        if self._enable_deepep_moe and (self.ep_size > 1):
             final_hidden_states = self.deepep_dispatcher.combine(
                 final_hidden_states,
                 topk_idx,
                 topk_weights,
                 forward_mode,
             )
+
         final_hidden_states *= self.routed_scaling_factor
 
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
 
-        return final_hidden_states
-
-    def _forward_shared_experts(self, hidden_states):
-        if self.n_share_experts_fusion == 0:
-            return self.shared_experts(hidden_states)
-        else:
-            return None
+        if (not self._enable_deepep_moe) and (self.tp_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
 
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
...
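Putting the hunks together, the unified method now runs one straight-line sequence with DeepEP-specific steps gated inline rather than split across `forward_normal`/`forward_deepep`. A condensed paraphrase of the post-commit control flow, not the verbatim file; `_deepep_route_and_compute` is a hypothetical stand-in for the inline select/dispatch/compute/combine block:

def forward(self, hidden_states, forward_batch=None):
    forward_mode = forward_batch.forward_mode
    has_work = (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
        forward_mode, hidden_states
    )

    # 1. Gate and shared experts run under one shared predicate.
    router_logits = self.gate(hidden_states) if has_work else None
    shared_output = (
        self.shared_experts(hidden_states)
        if has_work and self.n_share_experts_fusion == 0
        else None
    )

    # 2. Routed experts: DeepEP selects/dispatches/combines explicitly;
    #    the fused path routes internally from the logits.
    if self._enable_deepep_moe:
        final_hidden_states = self._deepep_route_and_compute(hidden_states)
    else:
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

    # 3. Shared tail: scale, add shared output, all-reduce only in plain TP mode.
    final_hidden_states *= self.routed_scaling_factor
    if shared_output is not None:
        final_hidden_states = final_hidden_states + shared_output
    if (not self._enable_deepep_moe) and (self.tp_size > 1):
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states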