Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f79f7778
Unverified
Commit
f79f7778
authored
Feb 04, 2026
by
Kunshang Ji
Committed by
GitHub
Feb 04, 2026
Browse files
[XPU][2/N] add unquantized moe support for xpu (#33659)
Signed-off-by:
Kunshang Ji
<
kunshang.ji@intel.com
>
parent
4c8d1bf3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
139 additions
and
34 deletions
+139
-34
.buildkite/scripts/hardware_ci/run-xpu-test.sh
.buildkite/scripts/hardware_ci/run-xpu-test.sh
+2
-0
requirements/xpu.txt
requirements/xpu.txt
+1
-1
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+4
-0
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
+10
-1
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
...executor/layers/fused_moe/unquantized_fused_moe_method.py
+2
-32
vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
+120
-0
No files found.
.buildkite/scripts/hardware_ci/run-xpu-test.sh
View file @
f79f7778
...
@@ -39,6 +39,8 @@ docker run \
...
@@ -39,6 +39,8 @@ docker run \
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
cd tests
cd tests
pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py
pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py
pytest -v -s v1/engine
pytest -v -s v1/engine
...
...
requirements/xpu.txt
View file @
f79f7778
...
@@ -15,4 +15,4 @@ torch==2.10.0+xpu
...
@@ -15,4 +15,4 @@ torch==2.10.0+xpu
torchaudio
torchaudio
torchvision
torchvision
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.0/vllm_xpu_kernels-0.1.0-cp312-cp312-linux_x86_64.whl
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.1/vllm_xpu_kernels-0.1.1-cp312-cp312-linux_x86_64.whl
\ No newline at end of file
\ No newline at end of file
vllm/model_executor/layers/fused_moe/__init__.py
View file @
f79f7778
...
@@ -100,6 +100,9 @@ if HAS_TRITON:
...
@@ -100,6 +100,9 @@ if HAS_TRITON:
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
,
TritonOrDeepGemmExperts
,
)
)
from
vllm.model_executor.layers.fused_moe.xpu_fused_moe
import
(
XPUExperts
,
)
__all__
+=
[
__all__
+=
[
"AiterExperts"
,
"AiterExperts"
,
...
@@ -117,6 +120,7 @@ if HAS_TRITON:
...
@@ -117,6 +120,7 @@ if HAS_TRITON:
"DeepGemmExperts"
,
"DeepGemmExperts"
,
"BatchedDeepGemmExperts"
,
"BatchedDeepGemmExperts"
,
"TritonOrDeepGemmExperts"
,
"TritonOrDeepGemmExperts"
,
"XPUExperts"
,
]
]
else
:
else
:
# Some model classes directly use the custom ops. Add placeholders
# Some model classes directly use the custom ops. Add placeholders
...
...
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
View file @
f79f7778
...
@@ -46,7 +46,6 @@ class UnquantizedMoeBackend(Enum):
...
@@ -46,7 +46,6 @@ class UnquantizedMoeBackend(Enum):
UNSUPPORTED_BACKEND
=
[
UNSUPPORTED_BACKEND
=
[
UnquantizedMoeBackend
.
FLASHINFER_TRTLLM
,
UnquantizedMoeBackend
.
FLASHINFER_TRTLLM
,
UnquantizedMoeBackend
.
CPU
,
UnquantizedMoeBackend
.
CPU
,
UnquantizedMoeBackend
.
XPU
,
UnquantizedMoeBackend
.
TPU
,
UnquantizedMoeBackend
.
TPU
,
UnquantizedMoeBackend
.
OOT
,
UnquantizedMoeBackend
.
OOT
,
]
]
...
@@ -196,4 +195,14 @@ def make_unquantized_moe_kernel(
...
@@ -196,4 +195,14 @@ def make_unquantized_moe_kernel(
quant_config
=
quant_config
,
quant_config
=
quant_config
,
),
),
)
)
elif
backend
==
UnquantizedMoeBackend
.
XPU
:
from
vllm.model_executor.layers.fused_moe
import
XPUExperts
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
XPUExperts
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
),
)
return
kernel
,
use_inplace
return
kernel
,
use_inplace
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
View file @
f79f7778
...
@@ -40,7 +40,7 @@ from vllm.model_executor.utils import replace_parameter, set_weight_attrs
...
@@ -40,7 +40,7 @@ from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.platforms.interface
import
CpuArchEnum
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
()
or
current_platform
.
is_xpu
()
:
from
.fused_batched_moe
import
BatchedTritonExperts
from
.fused_batched_moe
import
BatchedTritonExperts
from
.fused_moe
import
TritonExperts
from
.fused_moe
import
TritonExperts
else
:
else
:
...
@@ -71,7 +71,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -71,7 +71,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
self
.
_is_monolithic
=
(
self
.
_is_monolithic
=
(
current_platform
.
is_cpu
()
current_platform
.
is_cpu
()
or
current_platform
.
is_xpu
()
or
self
.
unquantized_backend
==
UnquantizedMoeBackend
.
FLASHINFER_TRTLLM
or
self
.
unquantized_backend
==
UnquantizedMoeBackend
.
FLASHINFER_TRTLLM
)
)
...
@@ -82,8 +81,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -82,8 +81,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""Select the monolithic implementation based on platform."""
"""Select the monolithic implementation based on platform."""
if
current_platform
.
is_cpu
():
if
current_platform
.
is_cpu
():
return
self
.
forward_monolithic_cpu
return
self
.
forward_monolithic_cpu
elif
current_platform
.
is_xpu
():
return
self
.
forward_monolithic_xpu
else
:
else
:
return
self
.
forward_monolithic_cuda
return
self
.
forward_monolithic_cuda
...
@@ -256,16 +253,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -256,16 +253,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
)
layer
.
w13_weight
=
Parameter
(
w13_weights_shuffled
,
requires_grad
=
False
)
layer
.
w13_weight
=
Parameter
(
w13_weights_shuffled
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_weights_shuffled
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_weights_shuffled
,
requires_grad
=
False
)
elif
self
.
unquantized_backend
==
UnquantizedMoeBackend
.
XPU
:
import
intel_extension_for_pytorch
as
ipex
ep_rank_start
=
self
.
moe
.
ep_rank
*
self
.
moe
.
num_local_experts
self
.
ipex_fusion
=
ipex
.
llm
.
modules
.
GatedMLPMOE
(
layer
.
w13_weight
,
layer
.
w2_weight
,
use_prepack
=
True
,
experts_start_id
=
ep_rank_start
,
)
elif
self
.
unquantized_backend
==
UnquantizedMoeBackend
.
CPU
:
elif
self
.
unquantized_backend
==
UnquantizedMoeBackend
.
CPU
:
from
vllm.model_executor.layers.fused_moe
import
cpu_fused_moe
from
vllm.model_executor.layers.fused_moe
import
cpu_fused_moe
...
@@ -297,7 +284,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -297,7 +284,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self
.
cpu_fused_moe
=
cpu_fused_moe
.
CPUFusedMOE
(
layer
)
self
.
cpu_fused_moe
=
cpu_fused_moe
.
CPUFusedMOE
(
layer
)
else
:
else
:
self
.
cpu_fused_moe
=
cpu_fused_moe
.
CPUFusedMOE
(
layer
)
self
.
cpu_fused_moe
=
cpu_fused_moe
.
CPUFusedMOE
(
layer
)
elif
current_platform
.
is_cuda_alike
():
elif
current_platform
.
is_cuda_alike
()
or
current_platform
.
is_xpu
()
:
self
.
_setup_kernel
(
self
.
_setup_kernel
(
layer
=
layer
,
layer
=
layer
,
w13
=
layer
.
w13_weight
,
w13
=
layer
.
w13_weight
,
...
@@ -399,20 +386,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -399,20 +386,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer
.
apply_router_weight_on_input
,
layer
.
apply_router_weight_on_input
,
layer
.
activation
,
layer
.
activation
,
)
)
def
forward_monolithic_xpu
(
self
,
layer
:
"FusedMoE"
,
# type: ignore[name-defined] # noqa: F821
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
ipex_fusion
(
x
,
layer
.
use_grouped_topk
,
layer
.
top_k
,
router_logits
,
layer
.
renormalize
,
layer
.
topk_group
,
layer
.
num_expert_group
,
custom_routing_function
=
layer
.
custom_routing_function
,
)
vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
0 → 100644
View file @
f79f7778
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEParallelConfig
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8StaticTensorSym
,
)
from
vllm.platforms
import
current_platform
if
current_platform
.
is_xpu
():
from
vllm_xpu_kernels.fused_moe_interface
import
xpu_fused_moe
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
    """Modular-kernel experts implementation backed by ``xpu_fused_moe``.

    The capability hooks below advertise what this implementation accepts
    (device, activation, quantization scheme, parallel config), and
    ``apply`` forwards the actual computation to the ``xpu_fused_moe``
    kernel entry point.
    """

    @property
    def expects_unquantized_inputs(self) -> bool:
        """Report that this implementation expects unquantized inputs."""
        return True

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Return the activation format used here (``Standard``)."""
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        """Return whether the current platform is XPU."""
        return current_platform.is_xpu()

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Report that the no-act-and-mul mode is not supported."""
        return False

    @staticmethod
    def _supports_activation(activation: str) -> bool:
        """Return whether ``activation`` is one of the accepted names."""
        return activation in ("silu", "gelu", "swigluoai")

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        """Accept every parallel configuration (the argument is not read)."""
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Return whether the (weight, activation) quant pair is accepted.

        Accepted pairs are fully unquantized ``(None, None)`` and
        ``(kFp8StaticTensorSym, None)``.
        """
        # TODO: dispatch based on device.
        requested = (weight_key, activation_key)
        return requested in (
            (None, None),
            (kFp8StaticTensorSym, None),
        )

    def supports_chunking(self) -> bool:
        """Report that chunked execution is not supported."""
        return False

    def supports_expert_map(self) -> bool:
        """Report that expert maps are supported."""
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return a no-op top-k weight-and-reduce step."""
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: str,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return ``(workspace1, workspace2, output)`` shapes.

        Both workspaces are requested as ``(0,)`` and the output shape is
        ``(M, K)``; all other arguments are unused.
        """
        return ((0,), (0,), (M, K))

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Invoke ``xpu_fused_moe`` with this layer's weights and routing.

        ``output`` is handed to the kernel as its ``output`` argument and
        nothing is returned (presumably the kernel fills ``output`` in
        place; confirm against the kernel documentation).

        ``global_num_experts``, ``expert_map``, ``workspace13``,
        ``workspace2``, ``expert_tokens_meta`` and
        ``apply_router_weight_on_input`` are accepted but not used here.
        """
        # The number of experts per token is the trailing dim of topk_ids.
        experts_per_token = topk_ids.size(-1)
        config = self.moe_config

        # NOTE(review): a1q_scale / a2_scale are forwarded under the
        # kernel's weight-scale keywords (w13_scales / w2_scales); confirm
        # this mapping is intended for the quantized scheme advertised by
        # _supports_quant_scheme.
        # NOTE(review): expert_map is not forwarded; expert placement is
        # conveyed via ep_rank / ep_size instead. Verify against the kernel.
        xpu_fused_moe(
            hidden_states=hidden_states,
            w13=w1,
            w13_scales=a1q_scale,
            w13_bias=self.w1_bias,
            w2=w2,
            w2_scales=a2_scale,
            w2_bias=self.w2_bias,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            n_experts_per_token=experts_per_token,
            activation=activation,
            num_experts=config.num_local_experts,
            ep_rank=config.ep_rank,
            ep_size=config.ep_size,
            output=output,
        )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment