Unverified Commit e06de7f0 authored by Yan Ma's avatar Yan Ma Committed by GitHub
Browse files

[XPU] enable triton attention test on XPU by removing cuda device binding (#39627)


Signed-off-by: default avatarYan Ma <yan.ma@intel.com>
parent cc3993b0
...@@ -4,9 +4,12 @@ ...@@ -4,9 +4,12 @@
import pytest import pytest
import torch import torch
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
DEVICE_TYPE = current_platform.device_type
@pytest.mark.parametrize("B", [3, 5]) @pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025]) @pytest.mark.parametrize("L", [1027, 1025])
...@@ -25,33 +28,35 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ...@@ -25,33 +28,35 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
req_to_page = torch.randint( req_to_page = torch.randint(
0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda" 0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device=DEVICE_TYPE
) )
req_to_token = req_to_page * PAGE_SIZE req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1) req_to_token = req_to_token + torch.arange(PAGE_SIZE, device=DEVICE_TYPE).view(
1, 1, -1
)
req_to_token = req_to_token.view(B, -1) req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous() req_to_token = req_to_token[:, :seq_len].contiguous()
# q represents the new token being generated, one per batch # q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda") q = torch.randn(B, H_Q, D_QK, dtype=dtype, device=DEVICE_TYPE)
# k_buffer and v_buffer represent all previous tokens # k_buffer and v_buffer represent all previous tokens
# Page size is 1. # Page size is 1.
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda") k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device=DEVICE_TYPE)
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda") v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device=DEVICE_TYPE)
# o will have the same shape as q # o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=DEVICE_TYPE)
lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda") lse = torch.zeros(B, H_Q, dtype=dtype, device=DEVICE_TYPE)
b_seq_len = torch.full((B,), seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device=DEVICE_TYPE)
attn_logits = torch.empty( attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1), (B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32, dtype=torch.float32,
device="cuda", device=DEVICE_TYPE,
) )
# Call the original implementation. # Call the original implementation.
...@@ -127,25 +132,27 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE) ...@@ -127,25 +132,27 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
req_to_page = torch.randint( req_to_page = torch.randint(
0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda" 0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device=DEVICE_TYPE
) )
req_to_token = req_to_page * PAGE_SIZE req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1) req_to_token = req_to_token + torch.arange(PAGE_SIZE, device=DEVICE_TYPE).view(
1, 1, -1
)
req_to_token = req_to_token.view(B, -1) req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous() req_to_token = req_to_token[:, :seq_len].contiguous()
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda") q = torch.randn(B, H_Q, D_QK, dtype=dtype, device=DEVICE_TYPE)
# Create BF16 K/V as reference # Create BF16 K/V as reference
k_bf16 = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda") k_bf16 = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device=DEVICE_TYPE)
v_bf16 = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda") v_bf16 = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device=DEVICE_TYPE)
# --- BF16 reference --- # --- BF16 reference ---
o_ref = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") o_ref = torch.zeros(B, H_Q, D_V, dtype=dtype, device=DEVICE_TYPE)
lse_ref = torch.zeros(B, H_Q, dtype=dtype, device="cuda") lse_ref = torch.zeros(B, H_Q, dtype=dtype, device=DEVICE_TYPE)
attn_logits = torch.empty( attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda" (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device=DEVICE_TYPE
) )
if PAGE_SIZE == 1: if PAGE_SIZE == 1:
...@@ -156,7 +163,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE) ...@@ -156,7 +163,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
o_ref, o_ref,
lse_ref, lse_ref,
req_to_token, req_to_token,
b_seq_len=torch.full((B,), seq_len, device="cuda"), b_seq_len=torch.full((B,), seq_len, device=DEVICE_TYPE),
attn_logits=attn_logits, attn_logits=attn_logits,
num_kv_splits=num_kv_splits, num_kv_splits=num_kv_splits,
sm_scale=sm_scale, sm_scale=sm_scale,
...@@ -171,7 +178,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE) ...@@ -171,7 +178,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
o_ref, o_ref,
lse_ref, lse_ref,
req_to_page, req_to_page,
b_seq_len=torch.full((B,), seq_len, device="cuda"), b_seq_len=torch.full((B,), seq_len, device=DEVICE_TYPE),
attn_logits=attn_logits, attn_logits=attn_logits,
num_kv_splits=num_kv_splits, num_kv_splits=num_kv_splits,
sm_scale=sm_scale, sm_scale=sm_scale,
...@@ -182,10 +189,10 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE) ...@@ -182,10 +189,10 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
k_fp8, k_scale = _quantize_to_fp8(k_bf16) k_fp8, k_scale = _quantize_to_fp8(k_bf16)
v_fp8, v_scale = _quantize_to_fp8(v_bf16) v_fp8, v_scale = _quantize_to_fp8(v_bf16)
o_fp8 = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") o_fp8 = torch.zeros(B, H_Q, D_V, dtype=dtype, device=DEVICE_TYPE)
lse_fp8 = torch.zeros(B, H_Q, dtype=dtype, device="cuda") lse_fp8 = torch.zeros(B, H_Q, dtype=dtype, device=DEVICE_TYPE)
attn_logits_fp8 = torch.empty( attn_logits_fp8 = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda" (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device=DEVICE_TYPE
) )
if PAGE_SIZE == 1: if PAGE_SIZE == 1:
...@@ -196,7 +203,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE) ...@@ -196,7 +203,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
o_fp8, o_fp8,
lse_fp8, lse_fp8,
req_to_token, req_to_token,
b_seq_len=torch.full((B,), seq_len, device="cuda"), b_seq_len=torch.full((B,), seq_len, device=DEVICE_TYPE),
attn_logits=attn_logits_fp8, attn_logits=attn_logits_fp8,
num_kv_splits=num_kv_splits, num_kv_splits=num_kv_splits,
sm_scale=sm_scale, sm_scale=sm_scale,
...@@ -213,7 +220,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE) ...@@ -213,7 +220,7 @@ def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE)
o_fp8, o_fp8,
lse_fp8, lse_fp8,
req_to_page, req_to_page,
b_seq_len=torch.full((B,), seq_len, device="cuda"), b_seq_len=torch.full((B,), seq_len, device=DEVICE_TYPE),
attn_logits=attn_logits_fp8, attn_logits=attn_logits_fp8,
num_kv_splits=num_kv_splits, num_kv_splits=num_kv_splits,
sm_scale=sm_scale, sm_scale=sm_scale,
......
...@@ -5,8 +5,11 @@ import pytest ...@@ -5,8 +5,11 @@ import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
DEVICE_TYPE = current_platform.device_type
def ref_masked_attention( def ref_masked_attention(
q: torch.Tensor, q: torch.Tensor,
...@@ -92,17 +95,19 @@ def test_context_attention( ...@@ -92,17 +95,19 @@ def test_context_attention(
torch.manual_seed(42) torch.manual_seed(42)
# Generate random sequence lengths for each batch # Generate random sequence lengths for each batch
seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda") seq_lens = torch.randint(
max_seq_len // 2, max_seq_len + 1, (B,), device=DEVICE_TYPE
)
total_tokens = seq_lens.sum().item() total_tokens = seq_lens.sum().item()
# Create batch start locations # Create batch start locations
b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda") b_start_loc = torch.zeros(B, dtype=torch.int32, device=DEVICE_TYPE)
b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0) b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
# Create input tensors # Create input tensors
q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda") q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device=DEVICE_TYPE)
k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=DEVICE_TYPE)
v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=DEVICE_TYPE)
o = torch.zeros_like(q) o = torch.zeros_like(q)
# Call Triton kernel # Call Triton kernel
...@@ -169,17 +174,19 @@ def test_context_attention_sliding_window( ...@@ -169,17 +174,19 @@ def test_context_attention_sliding_window(
torch.manual_seed(42) torch.manual_seed(42)
# Generate random sequence lengths for each batch # Generate random sequence lengths for each batch
seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda") seq_lens = torch.randint(
max_seq_len // 2, max_seq_len + 1, (B,), device=DEVICE_TYPE
)
total_tokens = seq_lens.sum().item() total_tokens = seq_lens.sum().item()
# Create batch start locations # Create batch start locations
b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda") b_start_loc = torch.zeros(B, dtype=torch.int32, device=DEVICE_TYPE)
b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0) b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
# Create input tensors # Create input tensors
q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda") q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device=DEVICE_TYPE)
k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=DEVICE_TYPE)
v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=DEVICE_TYPE)
o = torch.zeros_like(q) o = torch.zeros_like(q)
# Call Triton kernel # Call Triton kernel
......
...@@ -10,6 +10,8 @@ from vllm.utils.math_utils import next_power_of_2 ...@@ -10,6 +10,8 @@ from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.ops.triton_unified_attention import unified_attention from vllm.v1.attention.ops.triton_unified_attention import unified_attention
DEVICE_TYPE = current_platform.device_type
NUM_HEADS = [(4, 4), (8, 2), (5, 1)] NUM_HEADS = [(4, 4), (8, 2), (5, 1)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16] BLOCK_SIZES = [16]
...@@ -114,7 +116,7 @@ def test_triton_unified_attn( ...@@ -114,7 +116,7 @@ def test_triton_unified_attn(
q_dtype: torch.dtype | None, q_dtype: torch.dtype | None,
seq_threshold_3D: int, seq_threshold_3D: int,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
set_random_seed(0) set_random_seed(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
...@@ -249,7 +251,7 @@ def test_triton_unified_attn_fp16_input_fp8_output( ...@@ -249,7 +251,7 @@ def test_triton_unified_attn_fp16_input_fp8_output(
seq_threshold_3D: int, seq_threshold_3D: int,
) -> None: ) -> None:
"""Test with fp16 input and fp8 output using output_scale.""" """Test with fp16 input and fp8 output using output_scale."""
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
set_random_seed(0) set_random_seed(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment