Unverified commit f35ec461, authored by Steve Luo, committed by GitHub
Browse files

[Bugfix] Fix deepseekv3 gate bias error (#12002)


Signed-off-by: mgoin <michael@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
parent 289b5191
...@@ -497,7 +497,10 @@ def grouped_topk(hidden_states: torch.Tensor, ...@@ -497,7 +497,10 @@ def grouped_topk(hidden_states: torch.Tensor,
raise ValueError(f"Unsupported scoring function: {scoring_func}") raise ValueError(f"Unsupported scoring function: {scoring_func}")
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
scores.add_(e_score_correction_bias.unsqueeze(0)) # Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
num_token = scores.shape[0] num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group, group_scores = scores.view(num_token, num_expert_group,
...@@ -510,10 +513,16 @@ def grouped_topk(hidden_states: torch.Tensor, ...@@ -510,10 +513,16 @@ def grouped_topk(hidden_states: torch.Tensor,
num_token, num_expert_group, num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk, if e_score_correction_bias is not None:
dim=-1, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
sorted=False) # Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment