Unverified Commit 2f427491 authored by Li Hui's avatar Li Hui Committed by GitHub
Browse files

Fix topk inference performance regression (#6474)

parent d8189660
...@@ -264,6 +264,8 @@ def biased_grouped_topk( ...@@ -264,6 +264,8 @@ def biased_grouped_topk(
# TODO merge into kernel for this branch # TODO merge into kernel for this branch
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
# TODO will fuse this into kernel, thus use slow manual operation now # TODO will fuse this into kernel, thus use slow manual operation now
if num_token_non_padded is None:
return topk_weights, topk_ids
torch.compile( torch.compile(
_mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend() _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
)(topk_ids, num_token_non_padded) )(topk_ids, num_token_non_padded)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment