Commit e1600abd authored by zhuwenwen
Browse files

update mla kernel and configs

parent 5c241fa9
...@@ -13,7 +13,6 @@ def cdiv(a, b): ...@@ -13,7 +13,6 @@ def cdiv(a, b):
@pytest.mark.parametrize("B", [1]) @pytest.mark.parametrize("B", [1])
# @pytest.mark.parametrize("L", [100]) # @pytest.mark.parametrize("L", [100])
# @pytest.mark.parametrize("L", [1,100,400,700,1000,1300,1600,1900,2200,2500,2800,3100,3400,3700,4000,4300,4600,4900,5000,5500,6000,6500,7000,7500,8000])
@pytest.mark.parametrize("L", [1,100,400,700,1000,1300,1600,1900,2200,2500,2800,3100,3400,3700,4000,4300,4600,4900,5000,5500,6000,6500,7000,7500,8000,8500,9000,9500,10000,10500,11000,11500,12000,12500,13000,13500,14000,14500,15000,15500,16000,16500,17000,17500,18000,18500,19000,19500,20000,20500,21000,21500,22000,22500,23000,23500,24000,24500,25000,25500,26000,26500,27000,27500,28000,28500,29000,29500,30000,30500,31000,31500,32000,32500]) @pytest.mark.parametrize("L", [1,100,400,700,1000,1300,1600,1900,2200,2500,2800,3100,3400,3700,4000,4300,4600,4900,5000,5500,6000,6500,7000,7500,8000,8500,9000,9500,10000,10500,11000,11500,12000,12500,13000,13500,14000,14500,15000,15500,16000,16500,17000,17500,18000,18500,19000,19500,20000,20500,21000,21500,22000,22500,23000,23500,24000,24500,25000,25500,26000,26500,27000,27500,28000,28500,29000,29500,30000,30500,31000,31500,32000,32500])
@pytest.mark.parametrize("H_Q", [4, 8, 16]) @pytest.mark.parametrize("H_Q", [4, 8, 16])
@pytest.mark.parametrize("H_KV", [1]) @pytest.mark.parametrize("H_KV", [1])
......
...@@ -491,14 +491,6 @@ def _decode_v1_kernel_stage1_use_tc( ...@@ -491,14 +491,6 @@ def _decode_v1_kernel_stage1_use_tc(
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1),
...@@ -515,18 +507,6 @@ def _decode_v1_kernel_stage1_use_tc( ...@@ -515,18 +507,6 @@ def _decode_v1_kernel_stage1_use_tc(
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
...@@ -1164,7 +1144,17 @@ def decode_attentionv2_fwd( ...@@ -1164,7 +1144,17 @@ def decode_attentionv2_fwd(
): ):
assert num_kv_splits == attn_logits.shape[2] assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2] kv_group_num = q.shape[1] // v_buffer.shape[-2]
b_start_loc = torch.arange(0, k_buffer.shape[0] * page_size, k_buffer.shape[0] * page_size // q.shape[0], device="cuda").to(torch.int32) num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < 128:
num_kv_splits = (127 + grid_num) // grid_num
attn_logits_v1 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
device="cuda",
)
if kv_group_num == 1: if kv_group_num == 1:
# MHA # MHA
decode_attention_fwd_normal( decode_attention_fwd_normal(
...@@ -1189,7 +1179,7 @@ def decode_attentionv2_fwd( ...@@ -1189,7 +1179,7 @@ def decode_attentionv2_fwd(
o, o,
req_to_token, req_to_token,
b_seq_len, b_seq_len,
attn_logits, attn_logits_v1,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
page_size, page_size,
......
...@@ -488,12 +488,12 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -488,12 +488,12 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
# version = 'das.opt1.cust2.' + sha[:7] version = 'das.opt1.alpha.' + sha[:7]
version = 'das.opt1.' + sha[:7] # version = 'das.opt1.' + sha[:7]
else: else:
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
# version = 'das.opt1.cust2' version = 'das.opt1.alpha'
version = 'das.opt1' # version = 'das.opt1'
# dtk version # dtk version
......
...@@ -757,14 +757,6 @@ def _decode_v1_kernel_stage1_use_tc( ...@@ -757,14 +757,6 @@ def _decode_v1_kernel_stage1_use_tc(
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
# triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1),
...@@ -781,18 +773,6 @@ def _decode_v1_kernel_stage1_use_tc( ...@@ -781,18 +773,6 @@ def _decode_v1_kernel_stage1_use_tc(
# triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
...@@ -1563,6 +1543,16 @@ def decode_attention_fwd( ...@@ -1563,6 +1543,16 @@ def decode_attention_fwd(
page_size, page_size,
logit_cap, logit_cap,
)''' )'''
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < 128:
num_kv_splits = (127 + grid_num) // grid_num
attn_logits_v2 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
device="cuda",
)
if best_config['kernel_kind'] == 'v1_2stages_tc': if best_config['kernel_kind'] == 'v1_2stages_tc':
attn_logits_v1 = torch.empty( attn_logits_v1 = torch.empty(
...@@ -1592,7 +1582,7 @@ def decode_attention_fwd( ...@@ -1592,7 +1582,7 @@ def decode_attention_fwd(
o, o,
req_to_token, req_to_token,
b_seq_len, b_seq_len,
attn_logits, attn_logits_v2,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config=best_config['best_config'], best_config=best_config['best_config'],
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment