"tests/python/common/test_heterograph-apply-edges.py" did not exist on "5eca59d8ddb3e7b3a391f3786b6de2c24bd3c499"
Unverified commit c6d549e7, authored by fzyzcjy and committed by GitHub

Multiple tiny code cleanups (#4608)

parent 3c09548d
@@ -185,7 +185,6 @@ class DeepEPDispatcher:
         previous_event=None,
         num_max_dispatch_tokens_per_rank: int = 128,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        self.hidden_shape = hidden_states.shape
         topk_idx = topk_idx.to(torch.int64)
         # Todo: enable low latency dispatch
         if True:  # not forward_mode.is_decode():
@@ -375,7 +374,7 @@ class DeepEPDispatcher:
             hidden_states, self.topk_idx, self.topk_weights, self.handle
         )
         self.handle = None
-        return hidden_states.view(self.hidden_shape)
+        return hidden_states

     def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
         combined_x, _, event = self.buffer_normal.combine(
...
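The two hunks above belong together: the dispatcher stops caching hidden_states.shape at dispatch time because the tensor that combine() hands back already has the original (num_tokens, hidden_dim) layout, which makes the final .view(self.hidden_shape) a no-op. A minimal sketch of that reasoning, assuming a shape-preserving round trip (the sizes and the stand-in for combine() below are illustrative, not from the commit):

import torch

# Stand-in for the tensor entering dispatch: (num_tokens, hidden_dim).
hidden_states = torch.randn(6, 4096)
cached_shape = hidden_states.shape  # what self.hidden_shape used to store

# Assumption: the dispatch/combine round trip preserves the 2D layout,
# so the combined tensor already has cached_shape.
combined = hidden_states.clone()

reshaped = combined.view(cached_shape)  # the view the commit removes
assert reshaped.shape == combined.shape  # identical shape
assert torch.equal(reshaped, combined)   # identical contents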
@@ -250,8 +250,6 @@ class DeepseekV2MoE(nn.Module):
         return self.forward_deepep(hidden_states, forward_mode)

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -264,13 +262,11 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states

     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_mode: ForwardMode
     ) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
         shared_output = None
         topk_idx = torch.full(
             (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
@@ -319,7 +315,7 @@ class DeepseekV2MoE(nn.Module):
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states


 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
...
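The DeepseekV2MoE hunks remove the same kind of redundancy: forward_normal and forward_deepep receive an already 2D hidden_states, so flattening with view(-1, hidden_dim) on entry and restoring with view(num_tokens, hidden_dim) on exit both return the input unchanged. A minimal sketch, assuming 2D input as the removed shape-unpacking implies (sizes are illustrative):

import torch

num_tokens, hidden_dim = 8, 5120
hidden_states = torch.randn(num_tokens, hidden_dim)

flattened = hidden_states.view(-1, hidden_dim)     # removed entry reshape
restored = flattened.view(num_tokens, hidden_dim)  # removed exit reshape

# Both views are no-ops on an already 2D tensor: same shape, and a view
# shares the same underlying storage as the original.
assert flattened.shape == hidden_states.shape
assert restored.data_ptr() == hidden_states.data_ptr()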