Unverified Commit a9dd3ec3 authored by erictanjn's avatar erictanjn Committed by GitHub
Browse files

fix:reorder topk experts to ensure shared expert replaces minimal score (#8125)

parent 45bc170b
......@@ -397,7 +397,9 @@ def grouped_topk_gpu(
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
topk_weights, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
)
if num_fused_shared_experts:
topk_ids[:, -1] = torch.randint(
low=num_experts,
......@@ -486,7 +488,9 @@ def biased_grouped_topk_impl(
tmp_scores = scores_for_choice.masked_fill(
~score_mask.bool(), float("-inf")
) # [n, e]
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
_, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
)
topk_weights = scores.gather(1, topk_ids)
if num_fused_shared_experts:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment