Unverified Commit da47621c authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Minor speedup topk postprocessing (#7058)

parent 22a6b9fc
......@@ -249,6 +249,15 @@ def _mask_topk_ids_padded_region(
topk_ids[indices >= num_token_non_padded, :] = -1
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _biased_grouped_topk_postprocess(
topk_ids, expert_location_dispatch_info, num_token_non_padded
):
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
return topk_ids
def biased_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
......@@ -282,14 +291,13 @@ def biased_grouped_topk(
num_fused_shared_experts,
routed_scaling_factor,
)
# TODO merge into kernel for this branch
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
# TODO will fuse this into kernel, thus use slow manual operation now
if num_token_non_padded is None:
return topk_weights, topk_ids
torch.compile(
_mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
)(topk_ids, num_token_non_padded)
# TODO merge into kernel
if (expert_location_dispatch_info is not None) or (
num_token_non_padded is not None
):
topk_ids = _biased_grouped_topk_postprocess(
topk_ids, expert_location_dispatch_info, num_token_non_padded
)
return topk_weights, topk_ids
else:
biased_grouped_topk_fn = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment