Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
044c3159
"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fead3ba3867d09a2ac0e21a2e7395be5d70c02d1"
Unverified
Commit
044c3159
authored
Mar 28, 2025
by
Qingquan Song
Committed by
GitHub
Mar 28, 2025
Browse files
Make torch compile configurable for biased_grouped_topk (#4749)
parent
4db29e82
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
2 deletions
+29
-2
python/sglang/srt/layers/moe/topk.py
python/sglang/srt/layers/moe/topk.py
+29
-2
No files found.
python/sglang/srt/layers/moe/topk.py
View file @
044c3159
...
@@ -129,8 +129,7 @@ def grouped_topk(
...
@@ -129,8 +129,7 @@ def grouped_topk(
return
topk_weights
.
to
(
torch
.
float32
),
topk_ids
.
to
(
torch
.
int32
)
return
topk_weights
.
to
(
torch
.
float32
),
topk_ids
.
to
(
torch
.
int32
)
@
torch
.
compile
(
dynamic
=
True
,
backend
=
get_compiler_backend
())
def
biased_grouped_topk_impl
(
def
biased_grouped_topk
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
correction_bias
:
torch
.
Tensor
,
correction_bias
:
torch
.
Tensor
,
...
@@ -171,6 +170,34 @@ def biased_grouped_topk(
...
@@ -171,6 +170,34 @@ def biased_grouped_topk(
return
topk_weights
.
to
(
torch
.
float32
),
topk_ids
.
to
(
torch
.
int32
)
return
topk_weights
.
to
(
torch
.
float32
),
topk_ids
.
to
(
torch
.
int32
)
def
biased_grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
correction_bias
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
,
compiled
:
bool
=
True
,
):
biased_grouped_topk_fn
=
(
torch
.
compile
(
biased_grouped_topk_impl
,
dynamic
=
True
,
backend
=
get_compiler_backend
()
)
if
compiled
else
biased_grouped_topk_impl
)
return
biased_grouped_topk_fn
(
hidden_states
,
gating_output
,
correction_bias
,
topk
,
renormalize
,
num_expert_group
,
topk_group
,
)
def
select_experts
(
def
select_experts
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment