Commit a30b3ce2 authored by zhuwenwen's avatar zhuwenwen
Browse files

update triton_decode_attention.py

parent 343a10fa
...@@ -753,6 +753,12 @@ def decode_attention_v1( ...@@ -753,6 +753,12 @@ def decode_attention_v1(
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=2), triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=2), triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=2), triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
...@@ -1146,11 +1152,17 @@ def decode_attentionv2_fwd( ...@@ -1146,11 +1152,17 @@ def decode_attentionv2_fwd(
): ):
assert num_kv_splits == attn_logits.shape[2] assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2] kv_group_num = q.shape[1] // v_buffer.shape[-2]
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16) num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0] grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < 128: if grid_num * num_kv_splits < cu_num:
num_kv_splits = (127 + grid_num) // grid_num num_kv_splits = (cu_num - 1 + grid_num) // grid_num
#[TODO] The relationship between L and block is to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
attn_logits_v1 = torch.empty( attn_logits_v1 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1), (q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32, dtype=torch.float32,
......
...@@ -1064,6 +1064,12 @@ def decode_attention_v1( ...@@ -1064,6 +1064,12 @@ def decode_attention_v1(
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=2), # triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=2), # triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=2), # triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
...@@ -1545,11 +1551,17 @@ def decode_attention_fwd( ...@@ -1545,11 +1551,17 @@ def decode_attention_fwd(
page_size, page_size,
logit_cap, logit_cap,
)''' )'''
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16) num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0] grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < 128: if grid_num * num_kv_splits < cu_num:
num_kv_splits = (127 + grid_num) // grid_num num_kv_splits = (cu_num - 1 + grid_num) // grid_num
#[TODO] The relationship between L and block is to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
attn_logits_v2 = torch.empty( attn_logits_v2 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1), (q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32, dtype=torch.float32,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment