Unverified commit 3ded4b21, authored by Ke Bao, committed by GitHub
Browse files

Revert "feat: update grouped_topk to support softmax and sigmoid" (#4505)

parent f4d7ab7a
...@@ -88,6 +88,7 @@ def fused_topk( ...@@ -88,6 +88,7 @@ def fused_topk(
return topk_weights, topk_ids return topk_weights, topk_ids
# This is used by the Deepseek V2/V3/R1 series models
@torch.compile(dynamic=True, backend=get_compiler_backend()) @torch.compile(dynamic=True, backend=get_compiler_backend())
def grouped_topk( def grouped_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -96,17 +97,10 @@ def grouped_topk( ...@@ -96,17 +97,10 @@ def grouped_topk(
renormalize: bool, renormalize: bool,
num_expert_group: int = 0, num_expert_group: int = 0,
topk_group: int = 0, topk_group: int = 0,
scoring_func: str = "softmax",
): ):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
if scoring_func == "softmax": scores = torch.softmax(gating_output, dim=-1)
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Scoring function '{scoring_func}' is not supported.")
num_token = scores.shape[0] num_token = scores.shape[0]
group_scores = ( group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values scores.view(num_token, num_expert_group, -1).max(dim=-1).values
...@@ -130,7 +124,6 @@ def grouped_topk( ...@@ -130,7 +124,6 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32) return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
# DeepSeek V2/V3/R1 uses biased_grouped_top
@torch.compile(dynamic=True, backend=get_compiler_backend()) @torch.compile(dynamic=True, backend=get_compiler_backend())
def biased_grouped_topk( def biased_grouped_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -185,7 +178,7 @@ def select_experts( ...@@ -185,7 +178,7 @@ def select_experts(
correction_bias: Optional[torch.Tensor] = None, correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False, torch_native: bool = False,
): ):
# DeepSeek V2/V3/R1 uses biased_grouped_top # DeekSeekv2 uses grouped_top_k
if use_grouped_topk: if use_grouped_topk:
assert topk_group is not None assert topk_group is not None
assert num_expert_group is not None assert num_expert_group is not None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment