Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
18b39828
Unverified
Commit
18b39828
authored
Nov 05, 2025
by
Kunshang Ji
Committed by
GitHub
Nov 05, 2025
Browse files
[XPU] Add gpt-oss model support for Intel GPU (#27786)
Signed-off-by:
Kunshang Ji
<
kunshang.ji@intel.com
>
parent
4ea62b77
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
101 additions
and
6 deletions
+101
-6
vllm/attention/utils/fa_utils.py
vllm/attention/utils/fa_utils.py
+7
-0
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+92
-2
vllm/model_executor/models/gpt_oss.py
vllm/model_executor/models/gpt_oss.py
+0
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+2
-1
No files found.
vllm/attention/utils/fa_utils.py
View file @
18b39828
...
...
@@ -80,6 +80,13 @@ def flash_attn_supports_fp8() -> bool:
)
def flash_attn_supports_sinks() -> bool:
    """Report whether the active attention backend accepts attention sinks.

    Returns:
        ``True`` unconditionally when ``current_platform.is_xpu()`` holds;
        otherwise ``True`` only if ``get_flash_attn_version()`` equals 3.
    """
    # The XPU path short-circuits: the FlashAttention version is not
    # consulted at all on that platform.
    if current_platform.is_xpu():
        return True
    return get_flash_attn_version() == 3
def
flash_attn_supports_mla
():
from
vllm.platforms
import
current_platform
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
18b39828
...
...
@@ -142,6 +142,9 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
else
:
logger
.
info_once
(
"Using Triton backend"
)
return
Mxfp4Backend
.
TRITON
elif
current_platform
.
is_xpu
():
logger
.
info_once
(
"Using ipex marlin backend on XPU"
)
return
Mxfp4Backend
.
MARLIN
elif
current_platform
.
is_rocm
()
and
has_triton_kernels
():
logger
.
info_once
(
"Using Triton backend"
)
return
Mxfp4Backend
.
TRITON
...
...
@@ -188,7 +191,10 @@ class Mxfp4Config(QuantizationConfig):
return
UnquantizedLinearMethod
()
raise
NotImplementedError
(
"Mxfp4 linear layer is not implemented"
)
elif
isinstance
(
layer
,
FusedMoE
):
return
Mxfp4MoEMethod
(
layer
.
moe_config
)
if
current_platform
.
is_xpu
():
return
IpexMxfp4MoEMethod
(
layer
.
moe_config
)
else
:
return
Mxfp4MoEMethod
(
layer
.
moe_config
)
elif
isinstance
(
layer
,
Attention
):
raise
NotImplementedError
(
"Mxfp4 attention layer is not implemented"
)
return
None
...
...
@@ -245,7 +251,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition_after_pad
=
round_up
(
intermediate_size_per_partition
,
128
)
hidden_size
=
round_up
(
hidden_size
,
256
)
if
current_platform
.
is_xpu
():
hidden_size
=
round_up
(
hidden_size
,
128
)
else
:
hidden_size
=
round_up
(
hidden_size
,
256
)
layer
.
params_dtype
=
params_dtype
layer
.
num_experts
=
num_experts
...
...
@@ -1071,3 +1080,84 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
else
:
raise
ValueError
(
f
"Unsupported backend:
{
self
.
mxfp4_backend
}
"
)
class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 fused-MoE method that dispatches to an IPEX ``GatedMLPMOE`` module.

    Weight creation is delegated to ``Mxfp4MoEMethod``; this subclass only
    overrides post-load processing and the forward ``apply`` step.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        """Initialise the parent method and keep a reference to ``moe_config``."""
        super().__init__(moe_config)
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights via the parent class and remember ``hidden_size``.

        The ``hidden_size`` value is stored exactly as the caller passed it
        (any rounding done inside the parent is not reflected here) so that
        ``apply`` can pad its input and trim its output against it.
        """
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Reinterpret the expert weights as int32 and build the IPEX module.

        The constructed module is attached to ``layer`` as ``ipex_fusion``.
        """
        # Imported lazily so that the dependency is only required when this
        # method is actually used.
        import intel_extension_for_pytorch as ipex

        # Same storage, viewed as int32 (no copy is made by ``view``).
        for weight_name in ("w13_weight", "w2_weight"):
            weight = getattr(layer, weight_name)
            weight.data = weight.data.view(torch.int32)

        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            w1_scale_inv=layer.w13_weight_scale,
            w2_scale_inv=layer.w2_weight_scale,
            w13_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            is_mxfp4=True,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MoE layer through ``layer.ipex_fusion``.

        The last dimension of ``x`` is zero-padded on the right up to
        ``round_up(self.original_hidden_size, 128)`` before the call, and the
        result is sliced back to ``self.original_hidden_size`` afterwards.

        Only ``layer``, ``x``, ``router_logits``, ``top_k``, ``renormalize``,
        ``use_grouped_topk``, ``topk_group``, ``num_expert_group`` and
        ``activation`` are read here; the remaining parameters are accepted
        but not used by this implementation.

        Raises:
            AssertionError: if ``activation`` is not ``"swigluoai"``.
        """
        assert activation == "swigluoai", (
            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
        )

        target_width = round_up(self.original_hidden_size, 128)
        right_padding = target_width - x.size(-1)
        padded_input = torch.nn.functional.pad(x, (0, right_padding))

        fused_output = layer.ipex_fusion(
            padded_input,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            activation="swiglu_oai",
        )

        # Drop the padded columns; ``contiguous`` materialises the slice.
        return fused_output[..., : self.original_hidden_size].contiguous()
vllm/model_executor/models/gpt_oss.py
View file @
18b39828
...
...
@@ -337,9 +337,6 @@ class GptOssModel(nn.Module):
if
is_pp_missing_parameter
(
name
,
self
):
continue
# FIXME(woosuk): Remove this after testing.
weight
=
weight
.
cuda
()
if
".w13_weight_scale"
in
name
:
# Handle MLP gate and up projection weights scale
if
use_ep
:
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
18b39828
...
...
@@ -27,6 +27,7 @@ from vllm.attention.utils.fa_utils import (
if
is_flash_attn_varlen_func_available
():
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_sinks
,
flash_attn_varlen_func
,
get_scheduler_metadata
,
reshape_and_cache_flash
,
...
...
@@ -497,7 +498,7 @@ class FlashAttentionImpl(AttentionImpl):
self
.
sinks
=
sinks
if
self
.
sinks
is
not
None
:
assert
self
.
vllm_
flash_attn_
version
==
3
,
(
assert
flash_attn_
supports_sinks
()
,
(
"Sinks are only supported in FlashAttention 3"
)
assert
self
.
sinks
.
shape
[
0
]
==
num_heads
,
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment