Unverified Commit b2388433 authored by fzyzcjy, committed by GitHub

Add back DeepSeek non-TBO branches (#6578)

parent a38376fa
@@ -324,6 +324,104 @@ class DeepseekV2MoE(nn.Module):
            if name not in ["correction_bias"]
        ]
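
    # Entry point: dispatch to the DeepEP expert-parallel path when it is
    # enabled, otherwise fall back to the standard tensor-parallel MoE path.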
    def forward(
        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
    ) -> torch.Tensor:
        if not self._enable_deepep_moe:
            return self.forward_normal(hidden_states)
        else:
            return self.forward_deepep(hidden_states, forward_batch)
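
    # Standard MoE path: compute the shared-expert output and router logits, run
    # the fused experts, scale by the routed scaling factor, and all-reduce
    # across tensor-parallel ranks.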
    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        shared_output = self._forward_shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        final_hidden_states *= self.routed_scaling_factor
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states
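
    # DeepEP expert-parallel path: route tokens, dispatch them to the ranks that
    # own the selected experts, run the experts, and combine the results back.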
    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        forward_mode = forward_batch.forward_mode
        shared_output = None
        if is_non_idle_and_non_empty(forward_mode, hidden_states):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=True,
                renormalize=self.renormalize,
                topk_group=self.topk_group,
                num_expert_group=self.num_expert_group,
                correction_bias=self.correction_bias,
                routed_scaling_factor=self.routed_scaling_factor,
                num_token_non_padded=forward_batch.num_token_non_padded,
            )
        else:
            topk_idx = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            topk_weights = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )
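        # With expert parallelism, dispatch tokens and their routing metadata to
        # the ranks that own the selected experts before running them.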
        if self.ep_size > 1:
            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            (
                hidden_states,
                topk_idx,
                topk_weights,
                reorder_topk_ids,
                num_recv_tokens_per_expert,
                seg_indptr,
                masked_m,
                expected_m,
            ) = self.deepep_dispatcher.dispatch(
                hidden_states=hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=forward_mode,
            )
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            reorder_topk_ids=reorder_topk_ids,
            seg_indptr=seg_indptr,
            masked_m=masked_m,
            expected_m=expected_m,
            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_mode=forward_mode,
        )
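        # Gather the per-expert outputs back to the ranks that originally held
        # the corresponding tokens.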
        if self.ep_size > 1:
            final_hidden_states = self.deepep_dispatcher.combine(
                hidden_states=final_hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=forward_mode,
            )
        final_hidden_states *= self.routed_scaling_factor
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        return final_hidden_states
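
    # Shared experts are evaluated separately only when they are not fused into
    # the routed experts (n_share_experts_fusion == 0).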
    def _forward_shared_experts(self, hidden_states):
        if self.n_share_experts_fusion == 0:
            return self.shared_experts(hidden_states)
        else:
            return None

    def op_gate(self, state):
        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, state.hidden_states_mlp_input
@@ -1353,17 +1451,29 @@ class DeepseekV2DecoderLayer(nn.Module):
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
        return execute_operations(
            inputs=dict(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                residual=residual,
                zero_allocator=zero_allocator,
            ),
            operations=compute_layer_operations(self),
        )
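
        # Explicit layer forward: the layer communicator prepares the attention
        # input, attention and MLP run in sequence, and the communicator
        # post-processes the layer output before returning it with the residual.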
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.mlp(hidden_states, forward_batch)

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual

    def op_comm_prepare_attn(
        self,
        state,
......