Unverified Commit a9dd3ec3 authored by erictanjn's avatar erictanjn Committed by GitHub
Browse files

fix:reorder topk experts to ensure shared expert replaces minimal score (#8125)

parent 45bc170b
...@@ -397,7 +397,9 @@ def grouped_topk_gpu( ...@@ -397,7 +397,9 @@ def grouped_topk_gpu(
.reshape(num_token, -1) .reshape(num_token, -1)
) # [n, e] ) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) topk_weights, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
)
if num_fused_shared_experts: if num_fused_shared_experts:
topk_ids[:, -1] = torch.randint( topk_ids[:, -1] = torch.randint(
low=num_experts, low=num_experts,
...@@ -486,7 +488,9 @@ def biased_grouped_topk_impl( ...@@ -486,7 +488,9 @@ def biased_grouped_topk_impl(
tmp_scores = scores_for_choice.masked_fill( tmp_scores = scores_for_choice.masked_fill(
~score_mask.bool(), float("-inf") ~score_mask.bool(), float("-inf")
) # [n, e] ) # [n, e]
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) _, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
)
topk_weights = scores.gather(1, topk_ids) topk_weights = scores.gather(1, topk_ids)
if num_fused_shared_experts: if num_fused_shared_experts:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment