Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0c227ee3
Unverified
Commit
0c227ee3
authored
Feb 21, 2025
by
zixuanzhang226
Committed by
GitHub
Feb 21, 2025
Browse files
feat: update grouped_topk to support softmax and sigmoid (#3680)
parent
5c54ef03
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
3 deletions
+10
-3
python/sglang/srt/layers/moe/topk.py
python/sglang/srt/layers/moe/topk.py
+10
-3
No files found.
python/sglang/srt/layers/moe/topk.py
View file @
0c227ee3
...
@@ -75,7 +75,6 @@ def fused_topk(
     return topk_weights, topk_ids
 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def grouped_topk(
     hidden_states: torch.Tensor,
...
@@ -84,10 +83,17 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    scoring_func: str = "softmax",
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-    scores = torch.softmax(gating_output, dim=-1)
+    if scoring_func == "softmax":
+        scores = torch.softmax(gating_output, dim=-1)
+    elif scoring_func == "sigmoid":
+        scores = gating_output.sigmoid()
+    else:
+        raise ValueError(f"Scoring function '{scoring_func}' is not supported.")
     num_token = scores.shape[0]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
...
@@ -111,6 +117,7 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+# DeepSeek V2/V3/R1 uses biased_grouped_top
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
...
@@ -165,7 +172,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-    # DeekSeek v2 uses grouped_top_k
+    # DeepSeek V2/V3/R1 uses biased_grouped_top
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment