Commit e4a9c2cd authored by zhuwenwen
Browse files

update mla optest

parent 52121d00
...@@ -688,7 +688,7 @@ package_data = { ...@@ -688,7 +688,7 @@ package_data = {
"model_executor/layers/fused_moe/configs/*.json", "model_executor/layers/fused_moe/configs/*.json",
"model_executor/layers/quantization/utils/configs/*.json", "model_executor/layers/quantization/utils/configs/*.json",
"benchmarks/*.py", "benchmarks/*.py",
"model_executor/layers/quantization/configs/w8a8/*.json", "attention/backends/configs/*.json",
"model_executor/layers/quantization/configs/awq/*.json" "model_executor/layers/quantization/configs/awq/*.json"
] ]
} }
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
import pytest import pytest
import torch import torch
import triton
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_decode_attention import decode_attention_fwd, decode_attention_v1, decode_attention_v2
def cdiv(a, b): def cdiv(a, b):
return (a + b - 1) // b return (a + b - 1) // b
...@@ -25,13 +25,13 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ...@@ -25,13 +25,13 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
sm_scale = 1.0 / (D_QK**0.5) sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 8 num_kv_splits = 8
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) # round up (ceiling division): e.g. 65 = (1027+16-1)//16
req_to_page = torch.randint(0, req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE, CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1), (B, num_pages_per_batch, 1), # tensor of shape (B, num_pages_per_batch, 1) with values in [0, CACHE_SIZE // PAGE_SIZE)
device="cuda") device="cuda")
req_to_token = req_to_page * PAGE_SIZE req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) # expand the last dimension, e.g. from torch.Size([3, 65, 1]) to torch.Size([3, 65, 16])
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view( req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1) 1, 1, -1)
req_to_token = req_to_token.view(B, -1) req_to_token = req_to_token.view(B, -1)
...@@ -47,14 +47,22 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ...@@ -47,14 +47,22 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
# o will have the same shape as q # o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda") b_seq_len = torch.full((B, ), seq_len, device="cuda")
b_start_loc = torch.arange(0, k_buffer.shape[0] * PAGE_SIZE, k_buffer.shape[0] * PAGE_SIZE // q.shape[0], device="cuda").to(torch.int32)
attn_logits_v1 = torch.empty(
(q.shape[1], k_buffer.shape[0]*PAGE_SIZE),
dtype=torch.float16,
device="cuda")
attn_logits = torch.empty( attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1), (B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32, dtype=torch.float32,
device="cuda", device="cuda",
) )
quantiles = [0.5, 0.2, 0.8]
# Call the original implementation. # Call the original implementation.
decode_attention_fwd( decode_attention_fwd(
...@@ -87,5 +95,81 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ...@@ -87,5 +95,81 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
sm_scale, sm_scale,
PAGE_SIZE, PAGE_SIZE,
) )
assert torch.allclose(o, o1) assert torch.allclose(o, o1)
# v0_tc_ms, v0_tc_min_ms, v0_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_fwd(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention ori kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v0_tc_ms)
decode_attention_v1(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_start_loc,
b_seq_len,
attn_logits_v1,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2)
# v1_tc_ms, v1_tc_min_ms, v1_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention v1 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v1_tc_ms)
decode_attention_v2(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2)
# v2_tc_ms, v2_tc_min_ms, v2_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_v2(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention v2 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v2_tc_ms)
...@@ -1420,7 +1420,7 @@ def decode_attention_v2( ...@@ -1420,7 +1420,7 @@ def decode_attention_v2(
_decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) _decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config
def decode_attention_fwd( def decode_attention_fwd(
q, q,
k_buffer, k_buffer,
...@@ -1455,21 +1455,7 @@ def decode_attention_fwd( ...@@ -1455,21 +1455,7 @@ def decode_attention_fwd(
) )
else: else:
# GQA/MQA/MLA # GQA/MQA/MLA
if not envs.VLLM_USE_TRITON_OPT_MLA: if envs.VLLM_USE_TRITON_OPT_MLA:
decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
else:
decode_attention_v2( decode_attention_v2(
q, q,
k_buffer, k_buffer,
...@@ -1501,7 +1487,7 @@ def decode_attention_fwd( ...@@ -1501,7 +1487,7 @@ def decode_attention_fwd(
# page_size, # page_size,
# logit_cap, # logit_cap,
# ) # )
# if best_config['kernel_kind'] == 'v1_2stages_tc': # if best_config['kernel_kind'] == 'v1_2stages_tc':
# attn_logits_v1 = torch.empty( # attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size), # (q.shape[1],k_buffer.shape[0]*page_size),
...@@ -1538,4 +1524,18 @@ def decode_attention_fwd( ...@@ -1538,4 +1524,18 @@ def decode_attention_fwd(
# logit_cap, # logit_cap,
# ) # )
# else: # else:
# print("Unknown mla kernel kind: ", best_config['kernel_kind']) # print("Unknown mla kernel kind: ", best_config['kernel_kind'])
\ No newline at end of file else:
decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment