sglang · Commit 48c1fa7b (unverified), authored Jul 18, 2025 by jianan-gu, committed by GitHub Jul 17, 2025
[CPU][Llama4] Fix Llama4 MoE inputs with "apply_router_weight_on_input" (#7889)
Parent: 8aa5ae6b
Showing 5 changed files with 35 additions and 4 deletions:
python/sglang/srt/configs/update_config.py (+3, -1)
python/sglang/srt/layers/moe/topk.py (+13, -0)
python/sglang/srt/layers/quantization/fp8.py (+6, -0)
python/sglang/srt/layers/quantization/unquant.py (+8, -3)
python/sglang/srt/layers/quantization/w8a8_int8.py (+5, -0)
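Background for the diffs below: with apply_router_weight_on_input (as in Llama4, which routes each token to a single expert), the router weight must scale the expert input, while the fused_experts_cpu kernel applies topk_weights to the expert output. The commit therefore pre-multiplies the inputs by the weights in Python and hands the kernel all-ones weights, so its output-side scaling becomes a no-op. A minimal self-contained sketch of why this reproduces the intended result; the toy expert, shapes, and names here are illustrative stand-ins, not the kernel's actual code:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_tokens, hidden = 4, 8
x = torch.randn(num_tokens, hidden)
topk_weights = torch.rand(num_tokens, 1)  # top_k == 1, as in Llama4 routing

w1 = torch.randn(hidden, hidden)

def toy_expert(t):
    # stand-in for an expert MLP; nonlinear, so scaling the input is
    # NOT the same as scaling the output
    return F.silu(t @ w1)

def kernel_output_scaling(t, weights):
    # stand-in for fused_experts_cpu, which applies weights on the output
    return toy_expert(t) * weights

# what the model wants: router weight applied on the input
want = toy_expert(x * topk_weights)

# the commit's workaround: pre-scale the inputs, neutralize the weights
x_scaled = x * topk_weights.to(x.dtype)
ones = torch.ones_like(topk_weights, dtype=torch.float32)
got = kernel_output_scaling(x_scaled, ones)

assert torch.allclose(want, got)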
python/sglang/srt/configs/update_config.py

@@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp(
     model_config = update_intermediate_size(
         model_config, "intermediate_size", intermediate_padding_size
     )
+    model_config = update_intermediate_size(
+        model_config, "intermediate_size_mlp", intermediate_padding_size
+    )
     return model_config
python/sglang/srt/layers/moe/topk.py

@@ -93,6 +93,19 @@ def fused_topk_cpu(
     return topk_weights, topk_ids


+def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
+    if not need_apply:
+        return inputs, topk_weights
+
+    # TODO: fuse below processing in fused_experts_cpu kernel
+    inputs = inputs * topk_weights.to(inputs.dtype)
+    topk_weights = torch.ones_like(
+        topk_weights, dtype=torch.float32
+    )  # clear topk_weights as already applied
+    return inputs, topk_weights
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
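A quick usage sketch of the new helper (dummy shapes; assumes the sglang.srt.layers.moe.topk module imports in your environment). Note that the broadcast in inputs * topk_weights effectively requires top_k == 1, which is the Llama4 case this commit targets:

import torch

from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

x = torch.randn(4, 8)            # [num_tokens, hidden]
topk_weights = torch.rand(4, 1)  # [num_tokens, top_k] with top_k == 1

# need_apply=True: weights are folded into x, then reset to ones
x2, w2 = apply_topk_weights_cpu(True, topk_weights, x)
assert torch.allclose(x2, x * topk_weights)
assert torch.equal(w2, torch.ones_like(topk_weights))

# need_apply=False: both tensors pass through unchanged
x3, w3 = apply_topk_weights_cpu(False, topk_weights, x)
assert x3 is x and w3 is topk_weights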
python/sglang/srt/layers/quantization/fp8.py

@@ -1005,6 +1005,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )

         if use_intel_amx_backend(layer):
+            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+            x, topk_weights = apply_topk_weights_cpu(
+                apply_router_weight_on_input, topk_weights, x
+            )
+
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
python/sglang/srt/layers/quantization/unquant.py

@@ -344,9 +344,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
         assert activation == "silu", f"activation = {activation} is not supported."

-        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
-            from sglang.srt.layers.moe.topk import select_experts
+        if use_intel_amx_backend(layer):
+            from sglang.srt.layers.moe.topk import (
+                select_experts,
+                apply_topk_weights_cpu,
+            )

             topk_weights, topk_ids = select_experts(
                 hidden_states=x,

@@ -361,8 +364,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 correction_bias=correction_bias,
                 routed_scaling_factor=routed_scaling_factor,
             )
-            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
+            x, topk_weights = apply_topk_weights_cpu(
+                apply_router_weight_on_input, topk_weights, x
+            )
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
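Note the behavioral change in this file: before this commit, the "and not apply_router_weight_on_input" guard pushed Llama4-style requests off the Intel AMX fast path entirely, while after it the weights are folded into x up front and the fused CPU kernel serves both routing conventions. The old TODO about supporting apply_router_weight_on_input inside fused_experts_cpu is dropped here; the helper in topk.py keeps a TODO to eventually fuse the pre-scaling into the kernel itself.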
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -497,6 +497,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         )

         if use_intel_amx_backend(layer):
+            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+            x, topk_weights = apply_topk_weights_cpu(
+                apply_router_weight_on_input, topk_weights, x
+            )
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,