change / sglang · Commit e984d507 (unverified)

enable aiter_biased_grouped_topk kernel (#7423)

Authored Jun 24, 2025 by valarLip; committed by GitHub, Jun 24, 2025
Parent commit: 755f3147
Showing 3 changed files, with 29 additions and 2 deletions (+29 -2):

    python/sglang/srt/layers/moe/topk.py                     +26 -0
    python/sglang/srt/model_executor/cuda_graph_runner.py     +1 -1
    python/sglang/srt/models/deepseek_v2.py                   +2 -1
python/sglang/srt/layers/moe/topk.py
@@ -30,6 +30,7 @@ from sglang.srt.managers.expert_location_dispatch import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
+    get_bool_env_var,
     get_compiler_backend,
     is_cpu,
     is_cuda,
@@ -38,6 +39,7 @@ from sglang.srt.utils import (
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
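The new aiter path is opt-in: it requires the SGLANG_USE_AITER environment variable and a ROCm (HIP) build. As a rough sketch of how a boolean env-var helper like get_bool_env_var typically behaves (the real sglang.srt.utils implementation may accept different spellings):

    import os

    def get_bool_env_var(name: str, default: str = "false") -> bool:
        # Illustrative sketch only, not the exact sglang helper:
        # treat common truthy spellings as True, everything else as False.
        return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

    # e.g. launch the server with SGLANG_USE_AITER=1 so that
    # _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
    # evaluates to True on a HIP build.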
@@ -46,6 +48,11 @@ if _is_cuda:
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
 
+if _use_aiter:
+    try:
+        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
+    except ImportError:
+        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 
 
 def fused_topk_torch_native(
@@ -347,6 +354,25 @@ def biased_grouped_topk_gpu(
             topk_ids, expert_location_dispatch_info, num_token_non_padded
         )
         return topk_weights, topk_ids
+    elif _use_aiter:
+        token = gating_output.shape[0]
+        device = gating_output.device
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
+        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
+        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+        aiter_biased_grouped_topk(
+            gating_output,
+            correction_bias,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
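Note that the aiter kernel fills the caller-allocated topk_weights and topk_ids buffers in place and, per the deepseek_v2.py change below, also folds routed_scaling_factor into the returned weights. For orientation, here is a hedged torch-native sketch of what DeepSeek-style biased grouped top-k routing computes, approximating the fallback path taken in the else branch; expert-location dispatch, padding, and torch.compile details are elided, and exact semantics may differ from the real implementation:

    import torch

    def biased_grouped_topk_reference(
        gating_output: torch.Tensor,    # [num_token, num_expert] router logits
        correction_bias: torch.Tensor,  # [num_expert] selection-only bias
        topk: int,
        num_expert_group: int,
        topk_group: int,
        renormalize: bool,
    ):
        # Sigmoid scores; the bias steers which experts are chosen but the
        # returned weights come from the unbiased scores (sketch only).
        scores = gating_output.sigmoid()
        num_token, num_expert = scores.shape
        scores_for_choice = scores + correction_bias.unsqueeze(0)

        # Score each expert group by the sum of its top-2 biased scores,
        # then keep only the best `topk_group` groups.
        group_scores = (
            scores_for_choice.view(num_token, num_expert_group, -1)
            .topk(2, dim=-1).values.sum(dim=-1)
        )
        group_idx = group_scores.topk(topk_group, dim=-1).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_token, num_expert_group, num_expert // num_expert_group)
            .reshape(num_token, -1)
        )

        # Top-k experts among the surviving groups; gather unbiased weights.
        masked = scores_for_choice.masked_fill(score_mask == 0, float("-inf"))
        topk_ids = masked.topk(topk, dim=-1).indices
        topk_weights = scores.gather(1, topk_ids)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids.to(torch.int32)

The aiter path differs mainly in being an out-parameter API (it writes into preallocated float32/int32 tensors) and in applying routed_scaling_factor before returning.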
python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -421,7 +421,7 @@ class CudaGraphRunner:
                 empty_cache=False,
             )
             capture_range.set_description(
-                f"Capturing batches ({avail_mem=:.2f} GB)"
+                f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
             )
             with patch_model(
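This change only adds the current batch size to the progress-bar text. It uses Python 3.8+ self-documenting f-string expressions, where {bs=} expands to "bs=<value>" and a format spec may follow the "=":

    bs = 160
    avail_mem = 23.4567  # hypothetical values for illustration
    print(f"Capturing batches ({bs=} {avail_mem=:.2f} GB)")
    # -> Capturing batches (bs=160 avail_mem=23.46 GB)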
python/sglang/srt/models/deepseek_v2.py
@@ -388,7 +388,8 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        if not _is_cuda:
+        if not _is_cuda and not _use_aiter:
+            # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
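Skipping the multiply here is safe because the expert combine is linear in the routing weights: scaling topk_weights inside the top-k kernel (the aiter path passes routed_scaling_factor into the kernel) is equivalent to scaling the combined output afterwards. A small numeric check of that equivalence, with all shapes and names hypothetical:

    import torch

    torch.manual_seed(0)
    num_token, topk, hidden = 4, 2, 8
    expert_out = torch.randn(num_token, topk, hidden)  # per-expert outputs
    topk_weights = torch.rand(num_token, topk)
    routed_scaling_factor = 2.5

    # Path A (non-aiter): combine first, then scale the result.
    combined = (expert_out * topk_weights.unsqueeze(-1)).sum(dim=1)
    a = combined * routed_scaling_factor

    # Path B (aiter): the scaling is folded into the routing weights, so the
    # later `final_hidden_states *= self.routed_scaling_factor` is skipped.
    b = (expert_out * (topk_weights * routed_scaling_factor).unsqueeze(-1)).sum(dim=1)

    assert torch.allclose(a, b, atol=1e-6)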