Unverified Commit 2f715f51 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Minor compile fused topk (#6944)

parent d664ca18
...@@ -89,6 +89,23 @@ def fused_topk( ...@@ -89,6 +89,23 @@ def fused_topk(
) )
del token_expert_indicies del token_expert_indicies
return _fused_topk_postprocess(
topk_weights=topk_weights,
topk_ids=topk_ids,
renormalize=renormalize,
expert_location_dispatch_info=expert_location_dispatch_info,
num_token_non_padded=num_token_non_padded,
)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _fused_topk_postprocess(
topk_weights,
topk_ids,
renormalize,
expert_location_dispatch_info,
num_token_non_padded,
):
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
...@@ -313,7 +330,6 @@ def select_experts( ...@@ -313,7 +330,6 @@ def select_experts(
num_token_non_padded: Optional[torch.Tensor] = None, num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
): ):
router_logits, correction_bias = ( router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs( expert_location_dispatch.transform_select_experts_inputs(
router_logits=router_logits, router_logits=router_logits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment