Unverified Commit 72bfb0ba authored by fzyzcjy, committed by GitHub

Refactor DeepSeek MoE layer to unify the two forward branches (#6325)

parent 15521495
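The idea of the refactor is to replace the separate `forward_normal` / `forward_deepep` entry points with a single `forward` that branches internally on an `_enable_deepep_moe` flag and on whether the current batch actually has work to do. Below is a minimal, framework-free sketch of that shape, using illustrative stand-in names (`ToyMoELayer`, `is_idle`) rather than the real SGLang classes:

```python
# Sketch only: a stand-in for the "one forward, two internal branches"
# structure this PR moves to. All names here are illustrative.
from typing import List


class ToyMoELayer:
    def __init__(self, enable_deepep_moe: bool):
        # Counterpart of the new _enable_deepep_moe property: one flag picks
        # the branch inside a single forward, instead of two forward methods.
        self._enable_deepep_moe = enable_deepep_moe

    def forward(self, tokens: List[float], is_idle: bool) -> List[float]:
        has_work = (not is_idle) and len(tokens) > 0
        # The non-DeepEP path always runs the gate; the DeepEP path skips it
        # for idle or empty steps, mirroring the guarded condition in the diff.
        if (not self._enable_deepep_moe) or has_work:
            return [2.0 * t for t in tokens]  # placeholder for gate + experts
        return []


if __name__ == "__main__":
    layer = ToyMoELayer(enable_deepep_moe=True)
    print(layer.forward([1.0, 2.0], is_idle=False))  # [2.0, 4.0]
    print(layer.forward([], is_idle=True))           # []
```

Keeping a single entry point means the shared tail (routed scaling, shared-expert addition, reductions) lives in one place instead of being duplicated across two methods.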
@@ -194,6 +194,14 @@ class MoEGate(nn.Module):
        return logits


def is_non_idle_and_non_empty(forward_mode, hidden_states):
    return (
        (forward_mode is not None)
        and not forward_mode.is_idle()
        and hidden_states.shape[0] > 0
    )
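A quick behavioral check of this guard, assuming a minimal stand-in object that only implements `is_idle()` (the stub class and tensors below are illustrative, not part of the diff):

```python
import torch


class _IdleModeStub:
    """Illustrative stand-in for ForwardMode: only is_idle() is modeled."""

    def __init__(self, idle: bool):
        self._idle = idle

    def is_idle(self) -> bool:
        return self._idle


hidden = torch.zeros(4, 8)  # non-empty batch of 4 tokens
empty = torch.zeros(0, 8)   # empty batch

assert is_non_idle_and_non_empty(_IdleModeStub(idle=False), hidden)      # real work
assert not is_non_idle_and_non_empty(_IdleModeStub(idle=True), hidden)   # idle mode
assert not is_non_idle_and_non_empty(_IdleModeStub(idle=False), empty)   # no tokens
assert not is_non_idle_and_non_empty(None, hidden)                       # no mode given
```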
class DeepseekV2MoE(nn.Module):
    def __init__(
@@ -259,11 +267,12 @@ class DeepseekV2MoE(nn.Module):
            ),
        )

        self.top_k = config.num_experts_per_tok

        if global_server_args_dict["enable_deepep_moe"]:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
            self.num_experts = config.n_routed_experts
            self.top_k = config.num_experts_per_tok
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
@@ -286,41 +295,30 @@ class DeepseekV2MoE(nn.Module):
                return_recv_hook=True,
            )

    @property
    def _enable_deepep_moe(self):
        return global_server_args_dict["enable_deepep_moe"]

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
    ) -> torch.Tensor:
        if not global_server_args_dict["enable_deepep_moe"]:
            return self.forward_normal(hidden_states)
        else:
            return self.forward_deepep(hidden_states, forward_batch)

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        shared_output = self._forward_shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        final_hidden_states *= self.routed_scaling_factor
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        forward_mode = forward_batch.forward_mode
        shared_output = None
        if (
            forward_mode is not None
            and not forward_mode.is_idle()
            and hidden_states.shape[0] > 0
        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
            forward_mode, hidden_states
        ):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
        else:
            router_logits = None

        if (self.n_share_experts_fusion == 0) and (
            (not self._enable_deepep_moe)
            or is_non_idle_and_non_empty(forward_mode, hidden_states)
        ):
            shared_output = self.shared_experts(hidden_states)
        else:
            shared_output = None

        if self._enable_deepep_moe and (router_logits is not None):
            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
@@ -340,7 +338,8 @@ class DeepseekV2MoE(nn.Module):
            topk_weights = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )

        if self.ep_size > 1:
        if self._enable_deepep_moe and (self.ep_size > 1):
            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            (
                hidden_states,
@@ -357,36 +356,41 @@ class DeepseekV2MoE(nn.Module):
                topk_weights,
                forward_mode=forward_mode,
            )

        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            reorder_topk_ids=reorder_topk_ids,
            seg_indptr=seg_indptr,
            masked_m=masked_m,
            expected_m=expected_m,
            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_mode=forward_mode,
        )

        if self.ep_size > 1:
        if self._enable_deepep_moe:
            final_hidden_states = self.experts(
                hidden_states=hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                reorder_topk_ids=reorder_topk_ids,
                seg_indptr=seg_indptr,
                masked_m=masked_m,
                expected_m=expected_m,
                num_recv_tokens_per_expert=num_recv_tokens_per_expert,
                forward_mode=forward_mode,
            )
        else:
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        if self._enable_deepep_moe and (self.ep_size > 1):
            final_hidden_states = self.deepep_dispatcher.combine(
                final_hidden_states,
                topk_idx,
                topk_weights,
                forward_mode,
            )

        final_hidden_states *= self.routed_scaling_factor

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        return final_hidden_states

        if (not self._enable_deepep_moe) and (self.tp_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

    def _forward_shared_experts(self, hidden_states):
        if self.n_share_experts_fusion == 0:
            return self.shared_experts(hidden_states)
        else:
            return None

        return final_hidden_states
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
......
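For orientation, the DeepEP branch of the unified forward above follows a dispatch → experts → combine sequence, with the TP all-reduce reserved for the non-DeepEP path. A simplified control-flow mock under toy stand-ins (none of these classes are the real DeepEPDispatcher or expert kernels; only the ordering mirrors the diff):

```python
import torch


class ToyDispatcher:
    """Illustrative stand-in for the EP dispatcher; both methods are no-ops."""

    def dispatch(self, hidden_states, topk_idx, topk_weights, forward_mode=None):
        # The real dispatcher scatters tokens to expert-parallel ranks.
        return hidden_states, topk_idx, topk_weights

    def combine(self, hidden_states, topk_idx, topk_weights, forward_mode=None):
        # The real combiner gathers expert outputs back to the source rank.
        return hidden_states


class ToyExperts:
    """Placeholder for the grouped expert computation."""

    def __call__(self, hidden_states, topk_idx, topk_weights):
        return hidden_states  # identity instead of real expert GEMMs


def deepep_branch(
    hidden_states, topk_idx, topk_weights, ep_size, routed_scaling_factor, shared_output=None
):
    dispatcher, experts = ToyDispatcher(), ToyExperts()
    if ep_size > 1:
        hidden_states, topk_idx, topk_weights = dispatcher.dispatch(
            hidden_states, topk_idx, topk_weights
        )
    out = experts(hidden_states, topk_idx, topk_weights)
    if ep_size > 1:
        out = dispatcher.combine(out, topk_idx, topk_weights)
    out = out * routed_scaling_factor
    if shared_output is not None:
        out = out + shared_output
    return out


if __name__ == "__main__":
    x = torch.randn(4, 16)
    idx = torch.zeros(4, 6, dtype=torch.long)
    w = torch.full((4, 6), 1.0 / 6)
    print(deepep_branch(x, idx, w, ep_size=2, routed_scaling_factor=2.5).shape)  # torch.Size([4, 16])
```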