sglang commit 38076dea (unverified)

apply fused moe gate in ds v3/r1 (#5371)

Authored Apr 15, 2025 by Xiaoyu Zhang; committed by GitHub, Apr 14, 2025
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Parent: 5e0a9b09

Showing 1 changed file with 37 additions and 16 deletions:
python/sglang/srt/layers/moe/topk.py (+37, -16)
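In short, the change lets DeepSeek-V3/R1-style biased grouped top-k routing call the moe_fused_gate CUDA kernel from sgl_kernel instead of the torch.compile'd Python implementation, whenever the run is on CUDA, shared-expert fusion is disabled, and the number of routed experts is a power of two. A minimal sketch of that dispatch condition, using assumed DeepSeek-V3/R1-like values that are not part of this commit:

import math

def is_power_of_two(n):
    # Same helper as added in the diff below.
    return n > 0 and math.log2(n).is_integer()

# Assumed, illustrative routing setup (roughly DeepSeek-V3/R1 defaults):
num_routed_experts = 256      # plays the role of correction_bias.shape[0]
n_share_experts_fusion = 0    # shared-expert fusion disabled
on_cuda = True                # plays the role of _is_cuda

# The commit takes the fused-gate kernel only when all three conditions hold;
# otherwise it keeps the existing torch.compile fallback.
use_fused_gate = (
    on_cuda
    and n_share_experts_fusion == 0
    and is_power_of_two(num_routed_experts)
)
print(use_fused_gate)  # True for 256 experts; False for, e.g., 160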
python/sglang/srt/layers/moe/topk.py (view file @ 38076dea)

@@ -12,6 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
+import math
 import os
 from typing import Callable, Optional
 
@@ -25,6 +26,8 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+if _is_cuda:
+    from sgl_kernel import moe_fused_gate
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
@@ -209,6 +212,10 @@ def biased_grouped_topk_impl(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+def is_power_of_two(n):
+    return n > 0 and math.log2(n).is_integer()
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -220,23 +227,37 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
 ):
-    biased_grouped_topk_fn = (
-        torch.compile(
-            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+    # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
+    if (
+        _is_cuda
+        and n_share_experts_fusion == 0
+        and is_power_of_two(correction_bias.shape[0])
+    ):
+        return moe_fused_gate(
+            gating_output,
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
         )
-        if compiled
-        else biased_grouped_topk_impl
-    )
-    return biased_grouped_topk_fn(
-        hidden_states,
-        gating_output,
-        correction_bias,
-        topk,
-        renormalize,
-        num_expert_group,
-        topk_group,
-        n_share_experts_fusion=n_share_experts_fusion,
-    )
+    else:
+        biased_grouped_topk_fn = (
+            torch.compile(
+                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+            )
+            if compiled
+            else biased_grouped_topk_impl
+        )
+        return biased_grouped_topk_fn(
+            hidden_states,
+            gating_output,
+            correction_bias,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
+        )
 
 
 def select_experts(
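For reference, a hedged usage sketch of the new fast path. It assumes a CUDA build of sgl_kernel that exposes moe_fused_gate (as imported in the diff) and DeepSeek-V3/R1-style routing hyperparameters (256 routed experts, 8 expert groups, top-4 groups, top-8 experts); the tensor shapes, dtypes, and device below are illustrative assumptions rather than requirements stated by this commit.

import torch
from sgl_kernel import moe_fused_gate  # same import the diff adds under _is_cuda

num_tokens, num_experts = 4, 256  # assumed DeepSeek-V3-like sizes
gating_output = torch.randn(num_tokens, num_experts, device="cuda")  # router logits
correction_bias = torch.zeros(num_experts, device="cuda")            # e_score_correction_bias-style term

# Positional arguments follow the call site in biased_grouped_topk above.
topk_weights, topk_ids = moe_fused_gate(
    gating_output,
    correction_bias,
    8,  # num_expert_group
    4,  # topk_group
    8,  # topk
)
print(topk_weights.shape, topk_ids.shape)  # expected: torch.Size([4, 8]) for each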