Unverified Commit 71d1785f authored by fzyzcjy, committed by GitHub
Browse files

Remove unnecessary `torch.full` in DeepSeek (#5601)

parent 3f87f831
...@@ -323,12 +323,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -323,12 +323,6 @@ class DeepseekV2MoE(nn.Module):
self, hidden_states: torch.Tensor, forward_mode: ForwardMode self, hidden_states: torch.Tensor, forward_mode: ForwardMode
) -> torch.Tensor: ) -> torch.Tensor:
shared_output = None shared_output = None
topk_idx = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if ( if (
forward_mode is not None forward_mode is not None
and not forward_mode.is_idle() and not forward_mode.is_idle()
...@@ -348,6 +342,13 @@ class DeepseekV2MoE(nn.Module): ...@@ -348,6 +342,13 @@ class DeepseekV2MoE(nn.Module):
correction_bias=self.correction_bias, correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
) )
else:
topk_idx = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if self.ep_size > 1: if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
( (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment