Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e06de7f0
Unverified
Commit
e06de7f0
authored
Apr 20, 2026
by
Yan Ma
Committed by
GitHub
Apr 20, 2026
Browse files
[XPU] enable triton attention test on XPU by removing cuda device binding (#39627)
Signed-off-by:
Yan Ma
<
yan.ma@intel.com
>
parent
cc3993b0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
52 additions
and
36 deletions
+52
-36
tests/kernels/attention/test_triton_decode_attention.py
tests/kernels/attention/test_triton_decode_attention.py
+31
-24
tests/kernels/attention/test_triton_prefill_attention.py
tests/kernels/attention/test_triton_prefill_attention.py
+17
-10
tests/kernels/attention/test_triton_unified_attention.py
tests/kernels/attention/test_triton_unified_attention.py
+4
-2
No files found.
tests/kernels/attention/test_triton_decode_attention.py
View file @
e06de7f0
...
...
@@ -4,9 +4,12 @@
import
pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.ops.triton_decode_attention
import
decode_attention_fwd
DEVICE_TYPE
=
current_platform
.
device_type
@
pytest
.
mark
.
parametrize
(
"B"
,
[
3
,
5
])
@
pytest
.
mark
.
parametrize
(
"L"
,
[
1027
,
1025
])
...
...
@@ -25,33 +28,35 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
num_pages_per_batch
=
cdiv
(
seq_len
,
PAGE_SIZE
)
req_to_page
=
torch
.
randint
(
0
,
CACHE_SIZE
//
PAGE_SIZE
,
(
B
,
num_pages_per_batch
,
1
),
device
=
"cuda"
0
,
CACHE_SIZE
//
PAGE_SIZE
,
(
B
,
num_pages_per_batch
,
1
),
device
=
DEVICE_TYPE
)
req_to_token
=
req_to_page
*
PAGE_SIZE
req_to_token
=
req_to_token
.
expand
(
B
,
num_pages_per_batch
,
PAGE_SIZE
)
req_to_token
=
req_to_token
+
torch
.
arange
(
PAGE_SIZE
,
device
=
"cuda"
).
view
(
1
,
1
,
-
1
)
req_to_token
=
req_to_token
+
torch
.
arange
(
PAGE_SIZE
,
device
=
DEVICE_TYPE
).
view
(
1
,
1
,
-
1
)
req_to_token
=
req_to_token
.
view
(
B
,
-
1
)
req_to_token
=
req_to_token
[:,
:
seq_len
].
contiguous
()
# q represents the new token being generated, one per batch
q
=
torch
.
randn
(
B
,
H_Q
,
D_QK
,
dtype
=
dtype
,
device
=
"cuda"
)
q
=
torch
.
randn
(
B
,
H_Q
,
D_QK
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
# k_buffer and v_buffer represent all previous tokens
# Page size is 1.
k_buffer
=
torch
.
randn
(
CACHE_SIZE
,
H_KV
,
D_QK
,
dtype
=
dtype
,
device
=
"cuda"
)
v_buffer
=
torch
.
randn
(
CACHE_SIZE
,
H_KV
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
k_buffer
=
torch
.
randn
(
CACHE_SIZE
,
H_KV
,
D_QK
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
v_buffer
=
torch
.
randn
(
CACHE_SIZE
,
H_KV
,
D_V
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
# o will have the same shape as q
o
=
torch
.
zeros
(
B
,
H_Q
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
o
=
torch
.
zeros
(
B
,
H_Q
,
D_V
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
lse
=
torch
.
zeros
(
B
,
H_Q
,
dtype
=
dtype
,
device
=
"cuda"
)
lse
=
torch
.
zeros
(
B
,
H_Q
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
)
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
DEVICE_TYPE
)
attn_logits
=
torch
.
empty
(
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
device
=
DEVICE_TYPE
,
)
# Call the original implementation.
...
...
@@ -127,25 +132,27 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
num_pages_per_batch
=
cdiv
(
seq_len
,
PAGE_SIZE
)
req_to_page
=
torch
.
randint
(
0
,
CACHE_SIZE
//
PAGE_SIZE
,
(
B
,
num_pages_per_batch
,
1
),
device
=
"cuda"
0
,
CACHE_SIZE
//
PAGE_SIZE
,
(
B
,
num_pages_per_batch
,
1
),
device
=
DEVICE_TYPE
)
req_to_token
=
req_to_page
*
PAGE_SIZE
req_to_token
=
req_to_token
.
expand
(
B
,
num_pages_per_batch
,
PAGE_SIZE
)
req_to_token
=
req_to_token
+
torch
.
arange
(
PAGE_SIZE
,
device
=
"cuda"
).
view
(
1
,
1
,
-
1
)
req_to_token
=
req_to_token
+
torch
.
arange
(
PAGE_SIZE
,
device
=
DEVICE_TYPE
).
view
(
1
,
1
,
-
1
)
req_to_token
=
req_to_token
.
view
(
B
,
-
1
)
req_to_token
=
req_to_token
[:,
:
seq_len
].
contiguous
()
q
=
torch
.
randn
(
B
,
H_Q
,
D_QK
,
dtype
=
dtype
,
device
=
"cuda"
)
q
=
torch
.
randn
(
B
,
H_Q
,
D_QK
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
# Create BF16 K/V as reference
k_bf16
=
torch
.
randn
(
CACHE_SIZE
,
H_KV
,
D_QK
,
dtype
=
dtype
,
device
=
"cuda"
)
v_bf16
=
torch
.
randn
(
CACHE_SIZE
,
H_KV
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
k_bf16
=
torch
.
randn
(
CACHE_SIZE
,
H_KV
,
D_QK
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
v_bf16
=
torch
.
randn
(
CACHE_SIZE
,
H_KV
,
D_V
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
# --- BF16 reference ---
o_ref
=
torch
.
zeros
(
B
,
H_Q
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
lse_ref
=
torch
.
zeros
(
B
,
H_Q
,
dtype
=
dtype
,
device
=
"cuda"
)
o_ref
=
torch
.
zeros
(
B
,
H_Q
,
D_V
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
lse_ref
=
torch
.
zeros
(
B
,
H_Q
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
attn_logits
=
torch
.
empty
(
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
torch
.
float32
,
device
=
DEVICE_TYPE
)
if
PAGE_SIZE
==
1
:
...
...
@@ -156,7 +163,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
o_ref
,
lse_ref
,
req_to_token
,
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
),
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
DEVICE_TYPE
),
attn_logits
=
attn_logits
,
num_kv_splits
=
num_kv_splits
,
sm_scale
=
sm_scale
,
...
...
@@ -171,7 +178,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
o_ref
,
lse_ref
,
req_to_page
,
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
),
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
DEVICE_TYPE
),
attn_logits
=
attn_logits
,
num_kv_splits
=
num_kv_splits
,
sm_scale
=
sm_scale
,
...
...
@@ -182,10 +189,10 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
k_fp8
,
k_scale
=
_quantize_to_fp8
(
k_bf16
)
v_fp8
,
v_scale
=
_quantize_to_fp8
(
v_bf16
)
o_fp8
=
torch
.
zeros
(
B
,
H_Q
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
lse_fp8
=
torch
.
zeros
(
B
,
H_Q
,
dtype
=
dtype
,
device
=
"cuda"
)
o_fp8
=
torch
.
zeros
(
B
,
H_Q
,
D_V
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
lse_fp8
=
torch
.
zeros
(
B
,
H_Q
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
attn_logits_fp8
=
torch
.
empty
(
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
torch
.
float32
,
device
=
DEVICE_TYPE
)
if
PAGE_SIZE
==
1
:
...
...
@@ -196,7 +203,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
o_fp8
,
lse_fp8
,
req_to_token
,
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
),
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
DEVICE_TYPE
),
attn_logits
=
attn_logits_fp8
,
num_kv_splits
=
num_kv_splits
,
sm_scale
=
sm_scale
,
...
...
@@ -213,7 +220,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
o_fp8
,
lse_fp8
,
req_to_page
,
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
),
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
DEVICE_TYPE
),
attn_logits
=
attn_logits_fp8
,
num_kv_splits
=
num_kv_splits
,
sm_scale
=
sm_scale
,
...
...
tests/kernels/attention/test_triton_prefill_attention.py
View file @
e06de7f0
...
...
@@ -5,8 +5,11 @@ import pytest
import
torch
import
torch.nn.functional
as
F
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.ops.triton_prefill_attention
import
context_attention_fwd
DEVICE_TYPE
=
current_platform
.
device_type
def
ref_masked_attention
(
q
:
torch
.
Tensor
,
...
...
@@ -92,17 +95,19 @@ def test_context_attention(
torch
.
manual_seed
(
42
)
# Generate random sequence lengths for each batch
seq_lens
=
torch
.
randint
(
max_seq_len
//
2
,
max_seq_len
+
1
,
(
B
,),
device
=
"cuda"
)
seq_lens
=
torch
.
randint
(
max_seq_len
//
2
,
max_seq_len
+
1
,
(
B
,),
device
=
DEVICE_TYPE
)
total_tokens
=
seq_lens
.
sum
().
item
()
# Create batch start locations
b_start_loc
=
torch
.
zeros
(
B
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
b_start_loc
=
torch
.
zeros
(
B
,
dtype
=
torch
.
int32
,
device
=
DEVICE_TYPE
)
b_start_loc
[
1
:]
=
torch
.
cumsum
(
seq_lens
[:
-
1
],
dim
=
0
)
# Create input tensors
q
=
torch
.
randn
(
total_tokens
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
k
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
v
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
q
=
torch
.
randn
(
total_tokens
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
k
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
v
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
o
=
torch
.
zeros_like
(
q
)
# Call Triton kernel
...
...
@@ -169,17 +174,19 @@ def test_context_attention_sliding_window(
torch
.
manual_seed
(
42
)
# Generate random sequence lengths for each batch
seq_lens
=
torch
.
randint
(
max_seq_len
//
2
,
max_seq_len
+
1
,
(
B
,),
device
=
"cuda"
)
seq_lens
=
torch
.
randint
(
max_seq_len
//
2
,
max_seq_len
+
1
,
(
B
,),
device
=
DEVICE_TYPE
)
total_tokens
=
seq_lens
.
sum
().
item
()
# Create batch start locations
b_start_loc
=
torch
.
zeros
(
B
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
b_start_loc
=
torch
.
zeros
(
B
,
dtype
=
torch
.
int32
,
device
=
DEVICE_TYPE
)
b_start_loc
[
1
:]
=
torch
.
cumsum
(
seq_lens
[:
-
1
],
dim
=
0
)
# Create input tensors
q
=
torch
.
randn
(
total_tokens
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
k
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
v
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
q
=
torch
.
randn
(
total_tokens
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
k
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
v
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
DEVICE_TYPE
)
o
=
torch
.
zeros_like
(
q
)
# Call Triton kernel
...
...
tests/kernels/attention/test_triton_unified_attention.py
View file @
e06de7f0
...
...
@@ -10,6 +10,8 @@ from vllm.utils.math_utils import next_power_of_2
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.ops.triton_unified_attention
import
unified_attention
DEVICE_TYPE
=
current_platform
.
device_type
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
5
,
1
)]
HEAD_SIZES
=
[
128
,
256
]
BLOCK_SIZES
=
[
16
]
...
...
@@ -114,7 +116,7 @@ def test_triton_unified_attn(
q_dtype
:
torch
.
dtype
|
None
,
seq_threshold_3D
:
int
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
DEVICE_TYPE
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
...
...
@@ -249,7 +251,7 @@ def test_triton_unified_attn_fp16_input_fp8_output(
seq_threshold_3D
:
int
,
)
->
None
:
"""Test with fp16 input and fp8 output using output_scale."""
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
DEVICE_TYPE
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment