Unverified commit da47621c authored by fzyzcjy, committed by GitHub

Minor speedup topk postprocessing (#7058)

parent 22a6b9fc
@@ -249,6 +249,15 @@ def _mask_topk_ids_padded_region(
     topk_ids[indices >= num_token_non_padded, :] = -1
 
 
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def _biased_grouped_topk_postprocess(
+    topk_ids, expert_location_dispatch_info, num_token_non_padded
+):
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_ids
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -282,14 +291,13 @@ def biased_grouped_topk(
             num_fused_shared_experts,
             routed_scaling_factor,
         )
-        # TODO merge into kernel for this branch
-        topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
-        # TODO will fuse this into kernel, thus use slow manual operation now
-        if num_token_non_padded is None:
-            return topk_weights, topk_ids
-        torch.compile(
-            _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
-        )(topk_ids, num_token_non_padded)
+        # TODO merge into kernel
+        if (expert_location_dispatch_info is not None) or (
+            num_token_non_padded is not None
+        ):
+            topk_ids = _biased_grouped_topk_postprocess(
+                topk_ids, expert_location_dispatch_info, num_token_non_padded
+            )
         return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
...
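
What the diff changes: before, biased_grouped_topk re-built a torch.compile wrapper around _mask_topk_ids_padded_region on every call and ran topk_ids_logical_to_physical as a separate eager op; after, both postprocessing steps sit in one module-level helper decorated with @torch.compile, so the compiled callable is constructed once and the remap and the masking share a single compiled region. The sketch below illustrates the pattern outside of sglang; the names _mask_padded_region, postprocess_before and postprocess_after, and the index-based remap are assumptions for illustration only, not the actual sglang implementation (which also selects the backend via get_compiler_backend()).

# Minimal, self-contained sketch of the pattern in this commit (assumed names).
import torch


def _mask_padded_region(topk_ids: torch.Tensor, num_token_non_padded: torch.Tensor):
    # Invalidate (-1) the expert ids of padded tokens, in place.
    indices = torch.arange(topk_ids.shape[0], device=topk_ids.device)
    topk_ids[indices >= num_token_non_padded, :] = -1


def postprocess_before(topk_ids, physical_map, num_token_non_padded):
    # Old shape of the code: eager remap, then a torch.compile wrapper built
    # anew on every call just for the masking step.
    topk_ids = physical_map[topk_ids]  # stand-in for topk_ids_logical_to_physical
    if num_token_non_padded is None:
        return topk_ids
    torch.compile(_mask_padded_region, dynamic=True)(topk_ids, num_token_non_padded)
    return topk_ids


@torch.compile(dynamic=True)
def postprocess_after(topk_ids, physical_map, num_token_non_padded):
    # New shape: one module-level compiled helper covering both steps; the
    # compiled callable is created once and reused across calls.
    topk_ids = physical_map[topk_ids]
    _mask_padded_region(topk_ids, num_token_non_padded)
    return topk_ids


if __name__ == "__main__":
    topk_ids = torch.randint(0, 8, (6, 2))
    physical_map = torch.arange(8)  # identity logical->physical map for the demo
    num_valid = torch.tensor(4)     # pretend the last 2 rows are padding
    print(postprocess_after(topk_ids, physical_map, num_valid))

The commit title only claims a minor speedup; the gain presumably comes from skipping the per-call wrapper construction and from letting both postprocessing ops be traced together, and the new branch also skips the remap entirely when expert_location_dispatch_info and num_token_non_padded are both None.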