Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
da47621c
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "1b2ff4fb7f05ec82128765c366e6f75f4e3f05f7"
Unverified
Commit
da47621c
authored
Jun 13, 2025
by
fzyzcjy
Committed by
GitHub
Jun 13, 2025
Browse files
Minor speedup topk postprocessing (#7058)
parent
22a6b9fc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
8 deletions
+16
-8
python/sglang/srt/layers/moe/topk.py
python/sglang/srt/layers/moe/topk.py
+16
-8
No files found.
python/sglang/srt/layers/moe/topk.py
View file @
da47621c
...
@@ -249,6 +249,15 @@ def _mask_topk_ids_padded_region(
...
@@ -249,6 +249,15 @@ def _mask_topk_ids_padded_region(
topk_ids
[
indices
>=
num_token_non_padded
,
:]
=
-
1
topk_ids
[
indices
>=
num_token_non_padded
,
:]
=
-
1
@
torch
.
compile
(
dynamic
=
True
,
backend
=
get_compiler_backend
())
def
_biased_grouped_topk_postprocess
(
topk_ids
,
expert_location_dispatch_info
,
num_token_non_padded
):
topk_ids
=
topk_ids_logical_to_physical
(
topk_ids
,
expert_location_dispatch_info
)
_mask_topk_ids_padded_region
(
topk_ids
,
num_token_non_padded
)
return
topk_ids
def
biased_grouped_topk
(
def
biased_grouped_topk
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
...
@@ -282,14 +291,13 @@ def biased_grouped_topk(
...
@@ -282,14 +291,13 @@ def biased_grouped_topk(
num_fused_shared_experts
,
num_fused_shared_experts
,
routed_scaling_factor
,
routed_scaling_factor
,
)
)
# TODO merge into kernel for this branch
# TODO merge into kernel
topk_ids
=
topk_ids_logical_to_physical
(
topk_ids
,
expert_location_dispatch_info
)
if
(
expert_location_dispatch_info
is
not
None
)
or
(
# TODO will fuse this into kernel, thus use slow manual operation now
num_token_non_padded
is
not
None
if
num_token_non_padded
is
None
:
):
return
topk_weights
,
topk_ids
topk_ids
=
_biased_grouped_topk_postprocess
(
torch
.
compile
(
topk_ids
,
expert_location_dispatch_info
,
num_token_non_padded
_mask_topk_ids_padded_region
,
dynamic
=
True
,
backend
=
get_compiler_backend
()
)
)(
topk_ids
,
num_token_non_padded
)
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
else
:
else
:
biased_grouped_topk_fn
=
(
biased_grouped_topk_fn
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment