# SPDX-License-Identifier: Apache-2.0
"""Correctness tests for the Triton decode-attention kernels.

``decode_attention_fwd`` with page size 1 is used as the reference; the paged
``decode_attention_fwd`` call and the ``decode_attention_v1`` /
``decode_attention_v2`` variants are each compared against it.
"""

import pytest
import torch
import triton  # noqa: F401  (only needed when benchmarking; see note below)

from vllm.attention.ops.triton_decode_attention import (decode_attention_fwd,
                                                        decode_attention_v1,
                                                        decode_attention_v2)


def cdiv(a, b):
    """Return ``ceil(a / b)`` for non-negative ``a`` and positive ``b``."""
    return (a + b - 1) // b


@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    """Compare paged, v1 and v2 decode attention with the page-size-1 path.

    Every kernel under test writes into its own freshly zeroed output tensor,
    so a kernel that fails to write its output cannot pass by inheriting the
    (already correct) result left behind by a previous call.
    """
    assert CACHE_SIZE % PAGE_SIZE == 0
    dtype = torch.bfloat16
    seq_len = L  # This represents the number of tokens already in the sequence
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8

    # Ceil division, e.g. (1027 + 16 - 1) // 16 == 65 pages per request.
    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
    # Random page ids in [0, CACHE_SIZE // PAGE_SIZE), shape
    # (B, num_pages_per_batch, 1). Ids may repeat; every kernel below reads
    # the same mapping, so that does not affect the comparison.
    req_to_page = torch.randint(0,
                                CACHE_SIZE // PAGE_SIZE,
                                (B, num_pages_per_batch, 1),
                                device="cuda")
    # Expand each page id into the token slots it covers, broadcasting the
    # last dim: (B, num_pages, 1) -> (B, num_pages, PAGE_SIZE),
    # e.g. (3, 65, 1) -> (3, 65, 16).
    req_to_token = req_to_page * PAGE_SIZE
    req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
        1, 1, -1)
    # Flatten to per-request token slot ids and drop the padding past seq_len.
    req_to_token = req_to_token.view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    # q represents the new token being generated, one per batch
    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")

    # k_buffer and v_buffer represent all previous tokens
    # Page size is 1.
    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")

    # o has shape (B, H_Q, D_V): q's leading dims with the value head dim.
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

    b_seq_len = torch.full((B, ), seq_len, device="cuda")

    # Scratch for decode_attention_v1: one row per query head.
    # NOTE(review): presumably request i's logits start at b_start_loc[i];
    # confirm against decode_attention_v1.
    logits_len_v1 = CACHE_SIZE * PAGE_SIZE
    # NOTE(review): logits_len_v1 is not a multiple of B for the current
    # parametrization, so this arange yields B + 1 entries; presumably the
    # kernel only reads the first B.
    b_start_loc = torch.arange(0,
                               logits_len_v1,
                               logits_len_v1 // B,
                               device="cuda").to(torch.int32)
    attn_logits_v1 = torch.empty((H_Q, logits_len_v1),
                                 dtype=torch.float16,
                                 device="cuda")

    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1),
        dtype=torch.float32,
        device="cuda",
    )

    # Reference: the original implementation with page size 1.
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )

    # Page size can be larger than 1.
    k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
    v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)

    o1 = torch.zeros_like(o)
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o1,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )
    assert torch.allclose(o, o1)

    # Fresh zeroed output: reusing o1 (which already matches o) would let a
    # kernel that never writes its output pass this check.
    o_v1 = torch.zeros_like(o)
    decode_attention_v1(
        q,
        k_buffer,
        v_buffer,
        o_v1,
        req_to_page,
        b_start_loc,
        b_seq_len,
        attn_logits_v1,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )
    assert torch.allclose(o, o_v1, atol=1e-2, rtol=1e-2)

    # Same reasoning as above: do not reuse a buffer that already holds o.
    o_v2 = torch.zeros_like(o)
    decode_attention_v2(
        q,
        k_buffer,
        v_buffer,
        o_v2,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )
    assert torch.allclose(o, o_v2, atol=1e-2, rtol=1e-2)

    # To benchmark a kernel, wrap the call above in triton.testing.do_bench:
    #   ms, min_ms, max_ms = triton.testing.do_bench(
    #       lambda: decode_attention_v2(...), quantiles=[0.5, 0.2, 0.8])