Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
debd6e76
Unverified
Commit
debd6e76
authored
Mar 23, 2026
by
Kunshang Ji
Committed by
GitHub
Mar 23, 2026
Browse files
[XPU][MoE Refactor] Refactor xpu mxfp4 support into oracle (#37784)
Signed-off-by:
Kunshang Ji
<
kunshang.ji@intel.com
>
parent
9ace378a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
101 deletions
+54
-101
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
+23
-2
vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
+30
-0
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+1
-99
No files found.
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
View file @
debd6e76
...
@@ -141,7 +141,10 @@ def backend_to_kernel_cls(
...
@@ -141,7 +141,10 @@ def backend_to_kernel_cls(
return
[
AiterExperts
]
return
[
AiterExperts
]
elif
backend
==
Mxfp4MoeBackend
.
XPU
:
elif
backend
==
Mxfp4MoeBackend
.
XPU
:
raise
NotImplementedError
(
"XPU backend uses XpuMxfp4MoEMethod directly."
)
from
vllm.model_executor.layers.fused_moe.xpu_fused_moe
import
XPUExpertsMXFp4
return
[
XPUExpertsMXFp4
]
else
:
else
:
raise
ValueError
(
f
"Unknown MXFP4 MoE backend:
{
backend
.
value
}
"
)
raise
ValueError
(
f
"Unknown MXFP4 MoE backend:
{
backend
.
value
}
"
)
...
@@ -156,6 +159,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
...
@@ -156,6 +159,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
"triton"
:
Mxfp4MoeBackend
.
TRITON
,
"triton"
:
Mxfp4MoeBackend
.
TRITON
,
"marlin"
:
Mxfp4MoeBackend
.
MARLIN
,
"marlin"
:
Mxfp4MoeBackend
.
MARLIN
,
"ck"
:
Mxfp4MoeBackend
.
CK
,
"ck"
:
Mxfp4MoeBackend
.
CK
,
"xpu"
:
Mxfp4MoeBackend
.
XPU
,
}
}
if
backend
:
=
mapping
.
get
(
runner_backend
):
if
backend
:
=
mapping
.
get
(
runner_backend
):
return
backend
return
backend
...
@@ -178,6 +182,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
...
@@ -178,6 +182,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
Mxfp4MoeBackend
.
TRITON_UNFUSED
,
Mxfp4MoeBackend
.
TRITON_UNFUSED
,
Mxfp4MoeBackend
.
MARLIN
,
Mxfp4MoeBackend
.
MARLIN
,
Mxfp4MoeBackend
.
BATCHED_MARLIN
,
Mxfp4MoeBackend
.
BATCHED_MARLIN
,
Mxfp4MoeBackend
.
XPU
,
]
]
return
_AVAILABLE_BACKENDS
return
_AVAILABLE_BACKENDS
...
@@ -351,7 +356,13 @@ def select_mxfp4_moe_backend(
...
@@ -351,7 +356,13 @@ def select_mxfp4_moe_backend(
if
current_platform
.
is_xpu
():
if
current_platform
.
is_xpu
():
backend
=
Mxfp4MoeBackend
.
XPU
backend
=
Mxfp4MoeBackend
.
XPU
logger
.
info_once
(
_make_log_backend
(
backend
))
logger
.
info_once
(
_make_log_backend
(
backend
))
return
backend
,
None
return
_return_or_raise
(
Mxfp4MoeBackend
.
XPU
,
config
,
kMxfp4Static
,
None
,
activation_format
,
)
if
current_platform
.
is_cuda
()
or
current_platform
.
is_rocm
():
if
current_platform
.
is_cuda
()
or
current_platform
.
is_rocm
():
raise
NotImplementedError
(
raise
NotImplementedError
(
...
@@ -741,6 +752,16 @@ def convert_to_mxfp4_moe_kernel_format(
...
@@ -741,6 +752,16 @@ def convert_to_mxfp4_moe_kernel_format(
w13_bias
,
w13_bias
,
w2_bias
,
w2_bias
,
)
)
elif
mxfp4_backend
==
Mxfp4MoeBackend
.
XPU
:
# No additional transformation needed for XPU backend
return
(
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
,
w13_bias
,
w2_bias
,
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Unsupported mxfp4_backend:
{
mxfp4_backend
}
: "
f
"Unsupported mxfp4_backend:
{
mxfp4_backend
}
: "
...
...
vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
View file @
debd6e76
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey
,
QuantKey
,
kFp8DynamicTensorSym
,
kFp8DynamicTensorSym
,
kFp8StaticTensorSym
,
kFp8StaticTensorSym
,
kMxfp4Static
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -38,6 +39,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
...
@@ -38,6 +39,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
num_dispatchers
,
num_dispatchers
,
)
)
self
.
is_fp8
=
False
self
.
is_fp8
=
False
self
.
is_mxfp4
=
False
@
property
@
property
def
expects_unquantized_inputs
(
self
)
->
bool
:
def
expects_unquantized_inputs
(
self
)
->
bool
:
...
@@ -137,6 +139,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
...
@@ -137,6 +139,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
ep_size
=
self
.
moe_config
.
ep_size
,
ep_size
=
self
.
moe_config
.
ep_size
,
output
=
output
,
output
=
output
,
is_fp8
=
self
.
is_fp8
,
is_fp8
=
self
.
is_fp8
,
is_mxfp4
=
self
.
is_mxfp4
,
)
)
...
@@ -155,3 +158,30 @@ class XPUExpertsFp8(XPUExperts):
...
@@ -155,3 +158,30 @@ class XPUExpertsFp8(XPUExperts):
num_dispatchers
,
num_dispatchers
,
)
)
self
.
is_fp8
=
True
self
.
is_fp8
=
True
class
XPUExpertsMXFp4
(
XPUExperts
):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
max_num_tokens
:
int
|
None
=
None
,
num_dispatchers
:
int
|
None
=
None
,
):
super
().
__init__
(
moe_config
,
quant_config
,
max_num_tokens
,
num_dispatchers
,
)
self
.
is_mxfp4
=
True
@
staticmethod
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
SUPPORTED_W_A
=
[
(
kMxfp4Static
,
None
),
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
vllm/model_executor/layers/quantization/mxfp4.py
View file @
debd6e76
...
@@ -10,7 +10,6 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -10,7 +10,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE
,
FusedMoE
,
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEMethodBase
,
FusedMoEMethodBase
,
MoEActivation
,
)
)
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
...
@@ -33,7 +32,6 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -33,7 +32,6 @@ from vllm.model_executor.layers.quantization.base_config import (
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
is_layer_skipped
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
is_layer_skipped
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -80,10 +78,7 @@ class Mxfp4Config(QuantizationConfig):
...
@@ -80,10 +78,7 @@ class Mxfp4Config(QuantizationConfig):
)
)
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
if
current_platform
.
is_xpu
():
return
Mxfp4MoEMethod
(
layer
.
moe_config
)
return
XpuMxfp4MoEMethod
(
layer
.
moe_config
)
else
:
return
Mxfp4MoEMethod
(
layer
.
moe_config
)
elif
isinstance
(
layer
,
Attention
):
elif
isinstance
(
layer
,
Attention
):
logger
.
debug_once
(
logger
.
debug_once
(
"MXFP4 attention layer is not implemented. "
"MXFP4 attention layer is not implemented. "
...
@@ -420,96 +415,3 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -420,96 +415,3 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
expert_map
=
layer
.
expert_map
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
)
)
class
XpuMxfp4MoEMethod
(
Mxfp4MoEMethod
):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
):
super
().
__init__
(
moe_config
)
self
.
moe_config
=
moe_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
super
().
create_weights
(
layer
,
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
params_dtype
,
**
extra_weight_attrs
,
)
self
.
original_hidden_size
=
hidden_size
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
pass
@
property
def
is_monolithic
(
self
)
->
bool
:
return
True
def
apply_monolithic
(
self
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
assert
layer
.
activation
==
MoEActivation
.
SWIGLUOAI
,
(
"Only swiglu_oai activation is supported for "
f
"XPU MXFP4 MoE, not
{
layer
.
activation
}
."
)
from
vllm_xpu_kernels.fused_moe_interface
import
xpu_fused_moe
M
,
_
=
x
.
size
()
routing_weights
=
torch
.
empty
(
M
,
layer
.
top_k
,
dtype
=
torch
.
float32
,
device
=
x
.
device
)
selected_experts
=
torch
.
empty
(
M
,
layer
.
top_k
,
dtype
=
torch
.
int32
,
device
=
x
.
device
)
token_expert_indices
=
torch
.
empty
(
M
,
layer
.
top_k
,
dtype
=
torch
.
int32
,
device
=
x
.
device
)
if
layer
.
use_grouped_topk
:
routing_weights
,
selected_experts
=
torch
.
ops
.
_moe_C
.
fused_grouped_topk
(
x
,
router_logits
,
layer
.
top_k
,
layer
.
renormalize
,
n_expert_group
=
layer
.
num_expert_group
,
n_topk_group
=
layer
.
topk_group
,
scoring_func
=
layer
.
scoring_func
,
routed_scaling_factor
=
layer
.
routed_scaling_factor
,
bias
=
layer
.
e_score_correction_bias
,
)
else
:
torch
.
ops
.
_moe_C
.
topk_softmax
(
routing_weights
,
selected_experts
,
token_expert_indices
,
router_logits
,
layer
.
renormalize
,
layer
.
e_score_correction_bias
,
)
return
xpu_fused_moe
(
hidden_states
=
x
,
w13
=
layer
.
w13_weight
,
w13_bias
=
layer
.
w13_bias
if
self
.
moe
.
has_bias
else
None
,
w13_scales
=
layer
.
w13_weight_scale
,
w2
=
layer
.
w2_weight
,
w2_bias
=
layer
.
w2_bias
if
self
.
moe
.
has_bias
else
None
,
w2_scales
=
layer
.
w2_weight_scale
,
topk_weights
=
routing_weights
,
topk_ids
=
selected_experts
,
n_experts_per_token
=
layer
.
top_k
,
activation
=
layer
.
activation
.
value
,
num_experts
=
layer
.
local_num_experts
,
is_mxfp4
=
True
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment