Unverified Commit 11499535 authored by Atream's avatar Atream Committed by GitHub
Browse files

fix-gate-compile

parent e7882483
...@@ -125,7 +125,7 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase): ...@@ -125,7 +125,7 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
# adapted from https://github.com/vllm-project/vllm/blob/c77620d22d43daa7e0440e6267cbdd83f849ac64/vllm/model_executor/layers/fused_moe/fused_moe.py#L1071 # adapted from https://github.com/vllm-project/vllm/blob/c77620d22d43daa7e0440e6267cbdd83f849ac64/vllm/model_executor/layers/fused_moe/fused_moe.py#L1071
# This is used by the Deepseek-V2 and Deepseek-V3 model # This is used by the Deepseek-V2 and Deepseek-V3 model
#@torch.compile(dynamic=True) @torch.compile(dynamic=True)
def grouped_topk(hidden_states: torch.Tensor, def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
...@@ -225,9 +225,8 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase): ...@@ -225,9 +225,8 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
hidden_states.type(torch.float32), self.weight.type(torch.float32), None hidden_states.type(torch.float32), self.weight.type(torch.float32), None
) )
return grouped_topk(hidden_states, logits, return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
self.top_k, self.norm_topk_prob, self.n_group, self.topk_group, "sigmoid", self.e_score_correction_bias)
self.n_group, self.topk_group)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device if device is None: device = self.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment