Unverified Commit bdbb8d00 authored by JieXin Liang's avatar JieXin Liang Committed by GitHub
Browse files

[perf] slightly imporve DeepSeek-R1-FP4 TP8 (#7481)

parent 34c3f9b2
...@@ -362,12 +362,14 @@ class DeepseekV2MoE(nn.Module): ...@@ -362,12 +362,14 @@ class DeepseekV2MoE(nn.Module):
return self.forward_deepep(hidden_states, forward_batch) return self.forward_deepep(hidden_states, forward_batch)
def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
current_stream = torch.cuda.current_stream() current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream) self.alt_stream.wait_stream(current_stream)
shared_output = self._forward_shared_experts(hidden_states) shared_output = self._forward_shared_experts(hidden_states)
with torch.cuda.stream(self.alt_stream): with torch.cuda.stream(self.alt_stream):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment