Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b482f71e
Unverified
Commit
b482f71e
authored
Feb 11, 2026
by
zofia
Committed by
GitHub
Feb 11, 2026
Browse files
[XPU][7/N] enable xpu fp8 moe (#34202)
Signed-off-by:
Zhu, Zufang
<
zufang.zhu@intel.com
>
parent
1485396a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
52 additions
and
5 deletions
+52
-5
requirements/xpu.txt
requirements/xpu.txt
+1
-1
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+2
-0
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+10
-0
vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
+39
-4
No files found.
requirements/xpu.txt
View file @
b482f71e
...
@@ -15,4 +15,4 @@ torch==2.10.0+xpu
...
@@ -15,4 +15,4 @@ torch==2.10.0+xpu
torchaudio
torchaudio
torchvision
torchvision
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.1/vllm_xpu_kernels-0.1.1-cp312-cp312-linux_x86_64.whl
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.2/vllm_xpu_kernels-0.1.2-cp312-cp312-linux_x86_64.whl
\ No newline at end of file
vllm/model_executor/layers/fused_moe/__init__.py
View file @
b482f71e
...
@@ -102,6 +102,7 @@ if HAS_TRITON:
...
@@ -102,6 +102,7 @@ if HAS_TRITON:
)
)
from
vllm.model_executor.layers.fused_moe.xpu_fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.xpu_fused_moe
import
(
XPUExperts
,
XPUExperts
,
XPUExpertsFp8
,
)
)
__all__
+=
[
__all__
+=
[
...
@@ -121,6 +122,7 @@ if HAS_TRITON:
...
@@ -121,6 +122,7 @@ if HAS_TRITON:
"BatchedDeepGemmExperts"
,
"BatchedDeepGemmExperts"
,
"TritonOrDeepGemmExperts"
,
"TritonOrDeepGemmExperts"
,
"XPUExperts"
,
"XPUExperts"
,
"XPUExpertsFp8"
,
]
]
else
:
else
:
# Some model classes directly use the custom ops. Add placeholders
# Some model classes directly use the custom ops. Add placeholders
...
...
vllm/model_executor/layers/fused_moe/oracle/fp8.py
View file @
b482f71e
...
@@ -52,6 +52,7 @@ class Fp8MoeBackend(Enum):
...
@@ -52,6 +52,7 @@ class Fp8MoeBackend(Enum):
AITER
=
"AITER"
AITER
=
"AITER"
VLLM_CUTLASS
=
"VLLM_CUTLASS"
VLLM_CUTLASS
=
"VLLM_CUTLASS"
BATCHED_VLLM_CUTLASS
=
"BATCHED_VLLM_CUTLASS"
BATCHED_VLLM_CUTLASS
=
"BATCHED_VLLM_CUTLASS"
XPU
=
"XPU"
def
backend_to_kernel_cls
(
def
backend_to_kernel_cls
(
...
@@ -123,6 +124,13 @@ def backend_to_kernel_cls(
...
@@ -123,6 +124,13 @@ def backend_to_kernel_cls(
return
CutlassBatchedExpertsFp8
return
CutlassBatchedExpertsFp8
elif
backend
==
Fp8MoeBackend
.
XPU
:
from
vllm.model_executor.layers.fused_moe.xpu_fused_moe
import
(
XPUExpertsFp8
,
)
return
XPUExpertsFp8
else
:
else
:
raise
ValueError
(
f
"Unknown FP8 MoE backend:
{
backend
.
value
}
"
)
raise
ValueError
(
f
"Unknown FP8 MoE backend:
{
backend
.
value
}
"
)
...
@@ -154,6 +162,7 @@ def select_fp8_moe_backend(
...
@@ -154,6 +162,7 @@ def select_fp8_moe_backend(
Fp8MoeBackend
.
TRITON
,
Fp8MoeBackend
.
TRITON
,
Fp8MoeBackend
.
BATCHED_TRITON
,
Fp8MoeBackend
.
BATCHED_TRITON
,
Fp8MoeBackend
.
MARLIN
,
Fp8MoeBackend
.
MARLIN
,
Fp8MoeBackend
.
XPU
,
]
]
# NOTE(rob): We need to peek into the P/F selection to determine
# NOTE(rob): We need to peek into the P/F selection to determine
...
@@ -393,6 +402,7 @@ def convert_to_fp8_moe_kernel_format(
...
@@ -393,6 +402,7 @@ def convert_to_fp8_moe_kernel_format(
Fp8MoeBackend
.
BATCHED_TRITON
,
Fp8MoeBackend
.
BATCHED_TRITON
,
Fp8MoeBackend
.
VLLM_CUTLASS
,
Fp8MoeBackend
.
VLLM_CUTLASS
,
Fp8MoeBackend
.
BATCHED_VLLM_CUTLASS
,
Fp8MoeBackend
.
BATCHED_VLLM_CUTLASS
,
Fp8MoeBackend
.
XPU
,
]:
]:
raise
ValueError
(
f
"Unsupported FP8 MoE backend:
{
fp8_backend
.
value
}
"
)
raise
ValueError
(
f
"Unsupported FP8 MoE backend:
{
fp8_backend
.
value
}
"
)
...
...
vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
View file @
b482f71e
...
@@ -4,13 +4,16 @@ import torch
...
@@ -4,13 +4,16 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
TopKWeightAndReduceNoOP
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
QuantKey
,
kFp8DynamicTensorSym
,
kFp8StaticTensorSym
,
kFp8StaticTensorSym
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -20,6 +23,21 @@ if current_platform.is_xpu():
...
@@ -20,6 +23,21 @@ if current_platform.is_xpu():
class
XPUExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
class
XPUExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
max_num_tokens
:
int
|
None
=
None
,
num_dispatchers
:
int
|
None
=
None
,
):
super
().
__init__
(
moe_config
,
quant_config
,
max_num_tokens
,
num_dispatchers
,
)
self
.
is_fp8
=
False
@
property
@
property
def
expects_unquantized_inputs
(
self
)
->
bool
:
def
expects_unquantized_inputs
(
self
)
->
bool
:
return
True
return
True
...
@@ -49,10 +67,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -49,10 +67,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
weight_key
:
QuantKey
|
None
,
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
)
->
bool
:
# TODO: dispatch based on device.
SUPPORTED_W_A
=
[
SUPPORTED_W_A
=
[
(
None
,
None
),
(
None
,
None
),
(
kFp8StaticTensorSym
,
None
),
(
kFp8StaticTensorSym
,
None
),
(
kFp8StaticTensorSym
,
kFp8DynamicTensorSym
),
]
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
...
@@ -103,10 +121,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -103,10 +121,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
xpu_fused_moe
(
xpu_fused_moe
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
w13
=
w1
,
w13
=
w1
,
w13_scales
=
a1q
_scale
,
w13_scales
=
self
.
w1
_scale
,
w13_bias
=
self
.
w1_bias
,
w13_bias
=
self
.
w1_bias
,
w2
=
w2
,
w2
=
w2
,
w2_scales
=
a
2_scale
,
w2_scales
=
self
.
w
2_scale
,
w2_bias
=
self
.
w2_bias
,
w2_bias
=
self
.
w2_bias
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
...
@@ -116,5 +134,22 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -116,5 +134,22 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
ep_rank
=
self
.
moe_config
.
ep_rank
,
ep_rank
=
self
.
moe_config
.
ep_rank
,
ep_size
=
self
.
moe_config
.
ep_size
,
ep_size
=
self
.
moe_config
.
ep_size
,
output
=
output
,
output
=
output
,
is_fp8
=
self
.
is_fp8
,
)
class
XPUExpertsFp8
(
XPUExperts
):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
max_num_tokens
:
int
|
None
=
None
,
num_dispatchers
:
int
|
None
=
None
,
):
super
().
__init__
(
moe_config
,
quant_config
,
max_num_tokens
,
num_dispatchers
,
)
)
return
self
.
is_fp8
=
True
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment