Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
14771f71
Unverified
Commit
14771f71
authored
Mar 25, 2026
by
Kunshang Ji
Committed by
GitHub
Mar 25, 2026
Browse files
[XPU] support MLA model on Intel GPU (#37143)
Signed-off-by:
Kunshang Ji
<
kunshang.ji@intel.com
>
parent
189ddefb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
12 deletions
+15
-12
vllm/_xpu_ops.py
vllm/_xpu_ops.py
+1
-0
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+4
-0
vllm/model_executor/layers/quantization/input_quant_fp8.py
vllm/model_executor/layers/quantization/input_quant_fp8.py
+10
-0
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+0
-12
No files found.
vllm/_xpu_ops.py
View file @
14771f71
...
@@ -170,6 +170,7 @@ class xpu_ops:
...
@@ -170,6 +170,7 @@ class xpu_ops:
num_splits
=
0
,
num_splits
=
0
,
return_softmax_lse
:
bool
|
None
=
False
,
return_softmax_lse
:
bool
|
None
=
False
,
s_aux
:
torch
.
Tensor
|
None
=
None
,
s_aux
:
torch
.
Tensor
|
None
=
None
,
return_attn_probs
:
bool
|
None
=
False
,
):
):
assert
cu_seqlens_k
is
not
None
or
seqused_k
is
not
None
,
(
assert
cu_seqlens_k
is
not
None
or
seqused_k
is
not
None
,
(
"cu_seqlens_k or seqused_k must be provided"
"cu_seqlens_k or seqused_k must be provided"
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
14771f71
...
@@ -1059,6 +1059,10 @@ except ImportError:
...
@@ -1059,6 +1059,10 @@ except ImportError:
"MLA models using TRITON_MLA will require flash_attn. "
"MLA models using TRITON_MLA will require flash_attn. "
"AITER_MLA backends use aiter kernels instead."
"AITER_MLA backends use aiter kernels instead."
)
)
elif
current_platform
.
is_xpu
():
from
vllm._xpu_ops
import
xpu_ops
as
ops
flash_attn_varlen_func
=
ops
.
flash_attn_varlen_func
# type: ignore[no-redef]
def
dynamic_per_batched_tensor_quant
(
def
dynamic_per_batched_tensor_quant
(
...
...
vllm/model_executor/layers/quantization/input_quant_fp8.py
View file @
14771f71
...
@@ -165,6 +165,16 @@ class QuantFP8(CustomOp):
...
@@ -165,6 +165,16 @@ class QuantFP8(CustomOp):
# Fallback to CUDA implementation
# Fallback to CUDA implementation
return
self
.
forward_cuda
(
x
,
scale
,
scale_ub
)
return
self
.
forward_cuda
(
x
,
scale
,
scale_ub
)
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
|
None
=
None
,
scale_ub
:
torch
.
Tensor
|
None
=
None
,
use_triton
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# XPU can use same code path as CUDA.
return
self
.
forward_cuda
(
x
,
scale
,
scale_ub
,
use_triton
)
def
forward_native
(
def
forward_native
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
vllm/platforms/xpu.py
View file @
14771f71
...
@@ -160,7 +160,6 @@ class XPUPlatform(Platform):
...
@@ -160,7 +160,6 @@ class XPUPlatform(Platform):
@
classmethod
@
classmethod
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
# in V1(or with chunked prefill) block_size is 64
# in V1(or with chunked prefill) block_size is 64
if
cache_config
and
not
cache_config
.
user_specified_block_size
:
if
cache_config
and
not
cache_config
.
user_specified_block_size
:
...
@@ -209,17 +208,6 @@ class XPUPlatform(Platform):
...
@@ -209,17 +208,6 @@ class XPUPlatform(Platform):
if
vllm_config
.
kv_transfer_config
is
not
None
:
if
vllm_config
.
kv_transfer_config
is
not
None
:
vllm_config
.
kv_transfer_config
.
enable_permute_local_kv
=
True
vllm_config
.
kv_transfer_config
.
enable_permute_local_kv
=
True
if
model_config
and
model_config
.
use_mla
:
logger
.
info
(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled."
)
vllm_config
.
scheduler_config
.
enable_chunked_prefill
=
False
vllm_config
.
scheduler_config
.
max_num_batched_tokens
=
max
(
vllm_config
.
model_config
.
max_model_len
,
vllm_config
.
scheduler_config
.
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
)
# In some cases, the internal memory type cache can misdetect GPU
# In some cases, the internal memory type cache can misdetect GPU
# memory as host memory, also leading to invalid memory access.
# memory as host memory, also leading to invalid memory access.
# This cache can be disabled by setting UCX_MEMTYPE_CACHE=n.
# This cache can be disabled by setting UCX_MEMTYPE_CACHE=n.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment