Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
05339a7b
Unverified
Commit
05339a7b
authored
Feb 11, 2026
by
Li, Jiang
Committed by
GitHub
Feb 11, 2026
Browse files
[Bugfix][CPU] Fix llama4 inference on CPU (#34321)
Signed-off-by:
jiang1.li
<
jiang1.li@intel.com
>
parent
40b8f553
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
60 additions
and
18 deletions
+60
-18
.gitignore
.gitignore
+3
-0
csrc/cpu/cpu_fused_moe.cpp
csrc/cpu/cpu_fused_moe.cpp
+10
-3
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+3
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+2
-0
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+28
-8
vllm/v1/worker/cpu_worker.py
vllm/v1/worker/cpu_worker.py
+14
-5
No files found.
.gitignore
View file @
05339a7b
...
@@ -238,3 +238,6 @@ ep_kernels_workspace/
...
@@ -238,3 +238,6 @@ ep_kernels_workspace/
vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2.pyi
vllm/grpc/vllm_engine_pb2.pyi
# Ignore generated cpu headers
csrc/cpu/cpu_attn_dispatch_generated.h
csrc/cpu/cpu_fused_moe.cpp
View file @
05339a7b
...
@@ -147,7 +147,7 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
...
@@ -147,7 +147,7 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
const
int32_t
token_num
,
const
int32_t
expert_num
,
const
int32_t
token_num
,
const
int32_t
expert_num
,
const
int32_t
topk_num
,
const
int32_t
input_size_13
,
const
int32_t
topk_num
,
const
int32_t
input_size_13
,
const
int32_t
output_size_13
,
const
int32_t
input_size_2
,
const
int32_t
output_size_13
,
const
int32_t
input_size_2
,
const
int32_t
output_size_2
)
{
const
int32_t
output_size_2
,
const
bool
skip_weighted
)
{
using
scalar_vec_t
=
typename
cpu_utils
::
VecTypeTrait
<
scalar_t
>::
vec_t
;
using
scalar_vec_t
=
typename
cpu_utils
::
VecTypeTrait
<
scalar_t
>::
vec_t
;
constexpr
int32_t
gemm_n_tile_size
=
gemm_t
::
NSize
;
constexpr
int32_t
gemm_n_tile_size
=
gemm_t
::
NSize
;
constexpr
int32_t
gemm_m_tile_size
=
gemm_t
::
MaxMSize
;
constexpr
int32_t
gemm_m_tile_size
=
gemm_t
::
MaxMSize
;
...
@@ -582,6 +582,11 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
...
@@ -582,6 +582,11 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
scalar_t
*
__restrict__
curr_output_buffer
=
scalar_t
*
__restrict__
curr_output_buffer
=
output
+
token_id
*
output_size_2
;
output
+
token_id
*
output_size_2
;
if
(
skip_weighted
)
{
// Only for topk_num == 1
*
curr_weight
=
1.0
f
;
}
if
(
topk_num
>
1
)
{
if
(
topk_num
>
1
)
{
{
{
int32_t
w2_output_idx
=
curr_expand_token_id_index_buffer
[
0
];
int32_t
w2_output_idx
=
curr_expand_token_id_index_buffer
[
0
];
...
@@ -699,7 +704,7 @@ void cpu_fused_moe(
...
@@ -699,7 +704,7 @@ void cpu_fused_moe(
const
std
::
optional
<
torch
::
Tensor
>&
w2_bias
,
// [expert_num, output_size_2]
const
std
::
optional
<
torch
::
Tensor
>&
w2_bias
,
// [expert_num, output_size_2]
const
torch
::
Tensor
&
topk_weights
,
// [token_num, k], float32
const
torch
::
Tensor
&
topk_weights
,
// [token_num, k], float32
const
torch
::
Tensor
&
topk_id
,
// [token_num, k], int32
const
torch
::
Tensor
&
topk_id
,
// [token_num, k], int32
const
std
::
string
&
act
,
const
std
::
string
&
isa
)
{
const
bool
skip_weighted
,
const
std
::
string
&
act
,
const
std
::
string
&
isa
)
{
const
int32_t
token_num
=
input
.
size
(
0
);
const
int32_t
token_num
=
input
.
size
(
0
);
const
int32_t
input_size_13
=
input
.
size
(
1
);
const
int32_t
input_size_13
=
input
.
size
(
1
);
const
int64_t
input_stride
=
input
.
stride
(
0
);
const
int64_t
input_stride
=
input
.
stride
(
0
);
...
@@ -711,6 +716,8 @@ void cpu_fused_moe(
...
@@ -711,6 +716,8 @@ void cpu_fused_moe(
const
int32_t
topk_num
=
topk_id
.
size
(
1
);
const
int32_t
topk_num
=
topk_id
.
size
(
1
);
const
FusedMOEAct
act_type
=
get_act_type
(
act
);
const
FusedMOEAct
act_type
=
get_act_type
(
act
);
cpu_utils
::
ISA
isa_type
=
cpu_utils
::
get_isa
(
isa
);
cpu_utils
::
ISA
isa_type
=
cpu_utils
::
get_isa
(
isa
);
TORCH_CHECK
(
!
skip_weighted
||
topk_num
==
1
,
"skip_weighted is only supported for topk=1 on CPU"
);
VLLM_DISPATCH_FLOATING_TYPES
(
w13
.
scalar_type
(),
"cpu_fused_moe"
,
[
&
]()
{
VLLM_DISPATCH_FLOATING_TYPES
(
w13
.
scalar_type
(),
"cpu_fused_moe"
,
[
&
]()
{
CPU_ISA_DISPATCH_IMPL
(
isa_type
,
[
&
]()
{
CPU_ISA_DISPATCH_IMPL
(
isa_type
,
[
&
]()
{
...
@@ -721,7 +728,7 @@ void cpu_fused_moe(
...
@@ -721,7 +728,7 @@ void cpu_fused_moe(
w2_bias
.
has_value
()
?
w2_bias
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
w2_bias
.
has_value
()
?
w2_bias
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
topk_weights
.
data_ptr
<
float
>
(),
topk_id
.
data_ptr
<
int32_t
>
(),
act_type
,
topk_weights
.
data_ptr
<
float
>
(),
topk_id
.
data_ptr
<
int32_t
>
(),
act_type
,
token_num
,
expert_num
,
topk_num
,
input_size_13
,
output_size_13
,
token_num
,
expert_num
,
topk_num
,
input_size_13
,
output_size_13
,
input_size_2
,
output_size_2
);
input_size_2
,
output_size_2
,
skip_weighted
);
});
});
});
});
}
}
csrc/cpu/torch_bindings.cpp
View file @
05339a7b
...
@@ -119,8 +119,8 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
...
@@ -119,8 +119,8 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
const
std
::
optional
<
torch
::
Tensor
>&
w13_bias
,
const
std
::
optional
<
torch
::
Tensor
>&
w13_bias
,
const
std
::
optional
<
torch
::
Tensor
>&
w2_bias
,
const
std
::
optional
<
torch
::
Tensor
>&
w2_bias
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_id
,
const
std
::
string
&
act
,
const
torch
::
Tensor
&
topk_id
,
const
bool
skip_weighted
,
const
std
::
string
&
isa
);
const
std
::
string
&
act
,
const
std
::
string
&
isa
);
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
// vLLM custom ops
// vLLM custom ops
...
@@ -320,6 +320,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -320,6 +320,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
ops
.
def
(
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"bool skip_weighted, "
"str act, str isa) -> ()"
);
"str act, str isa) -> ()"
);
ops
.
impl
(
"cpu_fused_moe"
,
torch
::
kCPU
,
&
cpu_fused_moe
);
ops
.
impl
(
"cpu_fused_moe"
,
torch
::
kCPU
,
&
cpu_fused_moe
);
#endif
#endif
...
...
vllm/_custom_ops.py
View file @
05339a7b
...
@@ -3078,6 +3078,7 @@ def cpu_fused_moe(
...
@@ -3078,6 +3078,7 @@ def cpu_fused_moe(
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
act
:
str
,
act
:
str
,
isa
:
str
,
isa
:
str
,
skip_weighted
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
input
)
output
=
torch
.
empty_like
(
input
)
torch
.
ops
.
_C
.
cpu_fused_moe
(
torch
.
ops
.
_C
.
cpu_fused_moe
(
...
@@ -3089,6 +3090,7 @@ def cpu_fused_moe(
...
@@ -3089,6 +3090,7 @@ def cpu_fused_moe(
w2_bias
,
w2_bias
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
skip_weighted
,
act
,
act
,
isa
,
isa
,
)
)
...
...
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
View file @
05339a7b
...
@@ -238,7 +238,6 @@ class CPUFusedMOE:
...
@@ -238,7 +238,6 @@ class CPUFusedMOE:
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
in
_CPU_MOE_ACT_FN
,
f
"
{
activation
}
is not supported."
assert
activation
in
_CPU_MOE_ACT_FN
,
f
"
{
activation
}
is not supported."
assert
not
apply_router_weight_on_input
topk_weights
,
topk_ids
=
select_experts
(
topk_weights
,
topk_ids
=
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -261,6 +260,7 @@ class CPUFusedMOE:
...
@@ -261,6 +260,7 @@ class CPUFusedMOE:
topk_ids
,
topk_ids
,
activation
,
activation
,
global_num_experts
,
global_num_experts
,
apply_router_weight_on_input
,
)
)
def
check_grouped_gemm
(
def
check_grouped_gemm
(
...
@@ -355,7 +355,14 @@ class CPUFusedMOE:
...
@@ -355,7 +355,14 @@ class CPUFusedMOE:
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
activation
:
str
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
skip_weighted
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
skip_weighted
:
assert
topk_ids
.
size
(
1
)
==
1
,
(
"apply_router_weight_on_input is only implemented for topk=1"
)
input
.
mul_
(
topk_weights
.
to
(
input
.
dtype
))
output
=
cpu_fused_moe
(
output
=
cpu_fused_moe
(
input
,
input
,
layer
.
w13_weight
,
layer
.
w13_weight
,
...
@@ -366,6 +373,7 @@ class CPUFusedMOE:
...
@@ -366,6 +373,7 @@ class CPUFusedMOE:
topk_ids
,
topk_ids
,
activation
,
activation
,
self
.
isa
,
self
.
isa
,
skip_weighted
,
)
)
return
output
return
output
...
@@ -377,7 +385,14 @@ class CPUFusedMOE:
...
@@ -377,7 +385,14 @@ class CPUFusedMOE:
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
activation
:
str
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
skip_weighted
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
skip_weighted
:
assert
topk_ids
.
size
(
1
)
==
1
,
(
"apply_router_weight_on_input is only implemented for topk=1"
)
input
.
mul_
(
topk_weights
.
to
(
input
.
dtype
))
output
=
torch
.
empty_like
(
input
)
output
=
torch
.
empty_like
(
input
)
layer_id
=
id
(
layer
)
layer_id
=
id
(
layer
)
torch
.
ops
.
vllm
.
cpu_fused_moe_torch
(
torch
.
ops
.
vllm
.
cpu_fused_moe_torch
(
...
@@ -388,6 +403,7 @@ class CPUFusedMOE:
...
@@ -388,6 +403,7 @@ class CPUFusedMOE:
topk_ids
,
topk_ids
,
activation
,
activation
,
global_num_experts
,
global_num_experts
,
skip_weighted
,
)
)
return
output
return
output
...
@@ -401,6 +417,7 @@ def cpu_fused_moe_torch(
...
@@ -401,6 +417,7 @@ def cpu_fused_moe_torch(
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
activation
:
str
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
skip_weighted
:
bool
=
False
,
)
->
None
:
)
->
None
:
layer
=
_CPU_MOE_LAYER_CACHE
[
layer_id
]()
layer
=
_CPU_MOE_LAYER_CACHE
[
layer_id
]()
...
@@ -434,6 +451,9 @@ def cpu_fused_moe_torch(
...
@@ -434,6 +451,9 @@ def cpu_fused_moe_torch(
new_x
=
torch
.
empty_like
(
outs
)
new_x
=
torch
.
empty_like
(
outs
)
new_x
[
idxs
]
=
outs
new_x
[
idxs
]
=
outs
if
skip_weighted
:
final_out
=
new_x
else
:
final_out
=
(
final_out
=
(
new_x
.
view
(
*
topk_ids
.
shape
,
-
1
)
new_x
.
view
(
*
topk_ids
.
shape
,
-
1
)
.
type
(
topk_weights
.
dtype
)
.
type
(
topk_weights
.
dtype
)
...
...
vllm/v1/worker/cpu_worker.py
View file @
05339a7b
...
@@ -160,12 +160,21 @@ class CPUWorker(Worker):
...
@@ -160,12 +160,21 @@ class CPUWorker(Worker):
x
for
x
in
logical_cpu_list
if
x
.
numa_node
==
selected_numa_node
x
for
x
in
logical_cpu_list
if
x
.
numa_node
==
selected_numa_node
]
]
else
:
else
:
assert
len
(
logical_cpu_list
)
>=
self
.
parallel_config
.
world_size
# This is a bit tricky because the internal DP size
logical_cpu_list
=
sorted
(
logical_cpu_list
,
key
=
lambda
x
:
x
.
numa_node
)
# is always 1 for non-MoE models
sim_cpu_num_per_node
=
(
world_size_across_dp
=
(
len
(
logical_cpu_list
)
//
self
.
parallel_config
.
world_size
self
.
parallel_config
.
world_size
*
self
.
parallel_config
.
_api_process_count
)
)
start_idx
=
self
.
local_rank
*
sim_cpu_num_per_node
assert
len
(
logical_cpu_list
)
>=
world_size_across_dp
logical_cpu_list
=
sorted
(
logical_cpu_list
,
key
=
lambda
x
:
x
.
numa_node
)
sim_cpu_num_per_node
=
len
(
logical_cpu_list
)
//
world_size_across_dp
assert
self
.
parallel_config
.
data_parallel_rank_local
is
not
None
start_idx
=
(
self
.
local_rank
+
self
.
parallel_config
.
world_size
*
self
.
parallel_config
.
data_parallel_rank_local
)
*
sim_cpu_num_per_node
logical_cpu_list
=
logical_cpu_list
[
logical_cpu_list
=
logical_cpu_list
[
start_idx
:
(
start_idx
+
sim_cpu_num_per_node
)
start_idx
:
(
start_idx
+
sim_cpu_num_per_node
)
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment