Commit a30b3ce2 authored by zhuwenwen
Browse files

update triton_decode_attention.py

parent 343a10fa
......@@ -753,6 +753,12 @@ def decode_attention_v1(
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
......@@ -1146,11 +1152,17 @@ def decode_attentionv2_fwd(
):
assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2]
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < 128:
num_kv_splits = (127 + grid_num) // grid_num
if grid_num * num_kv_splits < cu_num:
num_kv_splits = (cu_num - 1 + grid_num) // grid_num
# TODO: the relationship between L and block still needs to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
attn_logits_v1 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
......
......@@ -1064,6 +1064,12 @@ def decode_attention_v1(
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
......@@ -1545,11 +1551,17 @@ def decode_attention_fwd(
page_size,
logit_cap,
)'''
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < 128:
num_kv_splits = (127 + grid_num) // grid_num
if grid_num * num_kv_splits < cu_num:
num_kv_splits = (cu_num - 1 + grid_num) // grid_num
# TODO: the relationship between L and block still needs to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
attn_logits_v2 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment