Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
16ded21e
Unverified
Commit
16ded21e
authored
Sep 04, 2025
by
Kunshang Ji
Committed by
GitHub
Sep 04, 2025
Browse files
[XPU] support Triton Attention backend on Intel GPU (#24149)
Signed-off-by:
Kunshang Ji
<
kunshang.ji@intel.com
>
parent
2b30afa4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
49 additions
and
15 deletions
+49
-15
.buildkite/scripts/hardware_ci/run-xpu-test.sh
.buildkite/scripts/hardware_ci/run-xpu-test.sh
+5
-4
vllm/_ipex_ops.py
vllm/_ipex_ops.py
+2
-3
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+6
-1
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+26
-2
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+10
-5
No files found.
.buildkite/scripts/hardware_ci/run-xpu-test.sh
View file @
16ded21e
...
@@ -30,10 +30,11 @@ docker run \
...
@@ -30,10 +30,11 @@ docker run \
bash
-c
'
bash
-c
'
set -e
set -e
echo $ZE_AFFINITY_MASK
echo $ZE_AFFINITY_MASK
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
cd tests
cd tests
pytest -v -s v1/core
pytest -v -s v1/core
pytest -v -s v1/engine
pytest -v -s v1/engine
...
...
vllm/_ipex_ops.py
View file @
16ded21e
...
@@ -242,10 +242,9 @@ class ipex_ops:
...
@@ -242,10 +242,9 @@ class ipex_ops:
k_scale_float
:
float
=
1.0
,
k_scale_float
:
float
=
1.0
,
v_scale_float
:
float
=
1.0
,
v_scale_float
:
float
=
1.0
,
)
->
None
:
)
->
None
:
assert
kv_cache_dtype
==
"auto"
# TODO: support FP8 kv cache.
ipex
.
llm
.
modules
.
PagedAttention
.
reshape_and_cache_flash
(
ipex
.
llm
.
modules
.
PagedAttention
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
)
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale_float
,
v_scale_float
)
@
staticmethod
@
staticmethod
def
flash_attn_varlen_func
(
def
flash_attn_varlen_func
(
...
...
vllm/attention/ops/paged_attn.py
View file @
16ded21e
...
@@ -6,9 +6,14 @@ from typing import List, Optional, Tuple
...
@@ -6,9 +6,14 @@ from typing import List, Optional, Tuple
import
torch
import
torch
from
vllm
import
_
cu
stom_ops
as
ops
from
vllm
.platforms
import
cu
rrent_platform
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.triton_utils
import
HAS_TRITON
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
...
...
vllm/platforms/xpu.py
View file @
16ded21e
...
@@ -37,14 +37,38 @@ class XPUPlatform(Platform):
...
@@ -37,14 +37,38 @@ class XPUPlatform(Platform):
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
,
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
,
has_sink
:
bool
)
->
str
:
has_sink
:
bool
)
->
str
:
if
selected_backend
is
not
None
and
selected_backend
!=
_Backend
.
IPEX
:
logger
.
info
(
"Cannot use %s backend on XPU."
,
selected_backend
)
use_v1
=
envs
.
VLLM_USE_V1
use_v1
=
envs
.
VLLM_USE_V1
if
not
use_v1
:
if
not
use_v1
:
raise
ValueError
(
"XPU backend only supports V1."
)
raise
ValueError
(
"XPU backend only supports V1."
)
TRITON_ATTN_VLLM_V1
=
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
# noqa: E501
FLASH_ATTN_V1
=
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
# noqa: E501
if
selected_backend
==
_Backend
.
TRITON_ATTN_VLLM_V1
:
logger
.
info_once
(
"Using Triton backend on V1 engine."
)
return
TRITON_ATTN_VLLM_V1
elif
selected_backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info_once
(
"Using Flash Attention backend on V1 engine."
)
return
FLASH_ATTN_V1
elif
selected_backend
:
raise
ValueError
(
f
"Invalid attention backend for
{
cls
.
device_name
}
, "
f
"with use_v1:
{
use_v1
}
use_mla:
{
use_mla
}
"
)
logger
.
info
(
"Using Flash Attention backend on V1 engine."
)
logger
.
info
(
"Using Flash Attention backend on V1 engine."
)
return
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
return
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
@
classmethod
def
is_kv_cache_dtype_supported
(
cls
,
kv_cache_dtype
:
str
,
model_config
:
"ModelConfig"
)
->
bool
:
"""
Check if the kv_cache_dtype is supported.
XPU only support fp8 kv cache with triton backend.
"""
if
envs
.
is_set
(
"VLLM_ATTENTION_BACKEND"
)
and
\
envs
.
VLLM_ATTENTION_BACKEND
==
"TRITON_ATTN_VLLM_V1"
:
return
kv_cache_dtype
in
[
"fp8_e4m3"
,
"fp8_e5m2"
,
"fp8"
]
return
False
@
classmethod
@
classmethod
def
set_device
(
cls
,
device
:
torch
.
device
)
->
None
:
def
set_device
(
cls
,
device
:
torch
.
device
)
->
None
:
"""
"""
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
16ded21e
...
@@ -7,7 +7,6 @@ from typing import ClassVar, Optional
...
@@ -7,7 +7,6 @@ from typing import ClassVar, Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
...
@@ -23,6 +22,11 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
...
@@ -23,6 +22,11 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
CommonAttentionMetadata
)
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -337,7 +341,7 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -337,7 +341,7 @@ class TritonAttentionImpl(AttentionImpl):
layer
.
_v_scale
,
layer
.
_v_scale
,
)
)
else
:
else
:
torch
.
ops
.
_C_cache_
ops
.
reshape_and_cache_flash
(
ops
.
reshape_and_cache_flash
(
key
,
key
,
value
,
value
,
key_cache
,
key_cache
,
...
@@ -354,9 +358,10 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -354,9 +358,10 @@ class TritonAttentionImpl(AttentionImpl):
num_tokens
,
num_heads
,
head_size
=
query
.
shape
num_tokens
,
num_heads
,
head_size
=
query
.
shape
assert
layer
.
_q_scale
==
1.0
,
\
assert
layer
.
_q_scale
==
1.0
,
\
"A non 1.0 q_scale is not currently supported."
"A non 1.0 q_scale is not currently supported."
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_cuda
():
# Skip Q quantization on ROCm, since dequantizing back to
# Skip Q quantization on ROCm and XPU, enable this on cuda
# f32 in the attention kernel is not supported.
# only, since dequantizing back to f32 in the attention kernel
# is not supported.
query
,
_
=
ops
.
scaled_fp8_quant
(
query
,
_
=
ops
.
scaled_fp8_quant
(
query
.
reshape
(
query
.
reshape
(
(
num_tokens
,
num_heads
*
head_size
)).
contiguous
(),
(
num_tokens
,
num_heads
*
head_size
)).
contiguous
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment