Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2f715f51
Unverified
Commit
2f715f51
authored
Jun 07, 2025
by
fzyzcjy
Committed by
GitHub
Jun 07, 2025
Browse files
Minor compile fused topk (#6944)
parent
d664ca18
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
1 deletion
+17
-1
python/sglang/srt/layers/moe/topk.py
python/sglang/srt/layers/moe/topk.py
+17
-1
No files found.
python/sglang/srt/layers/moe/topk.py
View file @
2f715f51
...
@@ -89,6 +89,23 @@ def fused_topk(
...
@@ -89,6 +89,23 @@ def fused_topk(
)
)
del
token_expert_indicies
del
token_expert_indicies
return
_fused_topk_postprocess
(
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
renormalize
=
renormalize
,
expert_location_dispatch_info
=
expert_location_dispatch_info
,
num_token_non_padded
=
num_token_non_padded
,
)
@
torch
.
compile
(
dynamic
=
True
,
backend
=
get_compiler_backend
())
def
_fused_topk_postprocess
(
topk_weights
,
topk_ids
,
renormalize
,
expert_location_dispatch_info
,
num_token_non_padded
,
):
if
renormalize
:
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
topk_ids
=
topk_ids_logical_to_physical
(
topk_ids
,
expert_location_dispatch_info
)
topk_ids
=
topk_ids_logical_to_physical
(
topk_ids
,
expert_location_dispatch_info
)
...
@@ -313,7 +330,6 @@ def select_experts(
...
@@ -313,7 +330,6 @@ def select_experts(
num_token_non_padded
:
Optional
[
torch
.
Tensor
]
=
None
,
num_token_non_padded
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_location_dispatch_info
:
Optional
[
ExpertLocationDispatchInfo
]
=
None
,
expert_location_dispatch_info
:
Optional
[
ExpertLocationDispatchInfo
]
=
None
,
):
):
router_logits
,
correction_bias
=
(
router_logits
,
correction_bias
=
(
expert_location_dispatch
.
transform_select_experts_inputs
(
expert_location_dispatch
.
transform_select_experts_inputs
(
router_logits
=
router_logits
,
router_logits
=
router_logits
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment