Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98caeadd
Unverified
Commit
98caeadd
authored
Nov 25, 2025
by
Fadi Arafeh
Committed by
GitHub
Nov 25, 2025
Browse files
[fix][cpu] Use a SwigluOAI impl which supports interleaved gate-up wei (#29273)
Signed-off-by:
Fadi Arafeh
<
fadi.arafeh@arm.com
>
parent
64deead7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
21 deletions
+8
-21
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+8
-21
No files found.
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
View file @
98caeadd
...
...
@@ -6,22 +6,7 @@ import torch
from
torch.nn
import
functional
as
F
from
vllm
import
_custom_ops
as
ops
def
silu_and_mul
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
d
=
x
.
shape
[
-
1
]
//
2
return
F
.
silu
(
x
[...,
:
d
])
*
x
[...,
d
:]
def
swigluoai_and_mul
(
x
:
torch
.
Tensor
,
alpha
:
float
=
1.702
,
limit
:
float
=
7.0
)
->
torch
.
Tensor
:
d
=
x
.
shape
[
-
1
]
//
2
gate
,
up
=
x
[...,
:
d
],
x
[...,
d
:]
gate
=
gate
.
clamp
(
max
=
limit
)
up
=
up
.
clamp
(
min
=-
limit
,
max
=
limit
)
glu
=
gate
*
torch
.
sigmoid
(
alpha
*
gate
)
return
(
up
+
1
)
*
glu
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
SwigluOAIAndMul
def
grouped_topk
(
...
...
@@ -227,6 +212,11 @@ class CPUFusedMOE:
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
)
self
.
act_to_impl
=
{
"silu"
:
SiluAndMul
(),
"swigluoai"
:
SwigluOAIAndMul
(),
}
def
__call__
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -246,7 +236,7 @@ class CPUFusedMOE:
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
assert
activation
in
{
"silu"
,
"swigluoai"
}
,
f
"
{
activation
}
is not supported."
assert
activation
in
self
.
act_to_impl
,
f
"
{
activation
}
is not supported."
assert
not
apply_router_weight_on_input
topk_weights
,
topk_ids
=
select_experts
(
hidden_states
=
x
,
...
...
@@ -283,10 +273,7 @@ class CPUFusedMOE:
tokens_for_this_expert
=
sorted_tokens
[
start_idx
:
end_idx
]
gate_up
=
layer
.
gate_up_linear
[
i
](
tokens_for_this_expert
)
if
activation
==
"swigluoai"
:
gate_up
=
swigluoai_and_mul
(
gate_up
)
else
:
gate_up
=
silu_and_mul
(
gate_up
)
gate_up
=
self
.
act_to_impl
[
activation
].
forward_native
(
gate_up
)
expert_out
=
layer
.
down_linear
[
i
](
gate_up
)
outputs
.
append
(
expert_out
)
start_idx
=
end_idx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment