Commit 8b5a09f6 authored by zhuwenwen
Browse files

update triton_decode_attention.py

parent f6044f1a
......@@ -40,6 +40,7 @@ is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"1"
os.environ["TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"]="0"
os.environ["TRITON_DEFAULT_ENABLE_NUM_VGPRS512"] = "1"
os.environ["MLIR_ENABLE_DUMP"] = "0"
logger = logging.getLogger(__name__)
......@@ -830,15 +831,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
# if BLOCK_DPE > 0:
# offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
# mask_dpe = offs_dpe < Lk
# off_qpe = (
# cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
# )
# qpe = tl.load(
# Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
# )
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
......@@ -867,11 +868,19 @@ def _decode_v2_kernel_stage1_use_tc(
k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0)
qk += tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
......@@ -1009,7 +1018,6 @@ def _decode_v2_stage1_use_tc(
@triton.autotune(
configs=[
triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=2, num_stages=1),
triton.Config({}, num_warps=4, num_stages=1),
......@@ -1163,6 +1171,9 @@ def decode_attentionv2_fwd(
#[TODO] The relationship between L and block is to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
if L >= 4096:
num_kv_splits = (4 * cu_num - 1 + grid_num) // grid_num
attn_logits_v1 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
......
......@@ -42,6 +42,7 @@ is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"1"
os.environ["TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"]="0"
os.environ["TRITON_DEFAULT_ENABLE_NUM_VGPRS512"] = "1"
os.environ["MLIR_ENABLE_DUMP"] = "0"
logger = logging.getLogger(__name__)
......@@ -1141,15 +1142,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
# if BLOCK_DPE > 0:
# offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
# mask_dpe = offs_dpe < Lk
# off_qpe = (
# cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
# )
# qpe = tl.load(
# Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
# )
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
......@@ -1179,11 +1180,21 @@ def _decode_v2_kernel_stage1_use_tc(
k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0)
qk += tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
......@@ -1335,7 +1346,6 @@ def _decode_v2_stage1_use_tc(
# @triton.autotune(
# configs=[
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=2, num_stages=1),
# triton.Config({}, num_warps=4, num_stages=1),
# triton.Config({}, num_warps=8, num_stages=1),
......@@ -1500,7 +1510,26 @@ def decode_attention_fwd(
):
assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2]
b_start_loc = torch.arange(0, req_to_token.shape[0]*req_to_token.shape[1], req_to_token.shape[0]*req_to_token.shape[1] // q.shape[0], device="cuda").to(torch.int32)
b_start_loc = torch.arange(0, req_to_token.shape[0]*req_to_token.shape[1]*page_size, req_to_token.shape[0]*req_to_token.shape[1]* page_size // q.shape[0], device="cuda").to(torch.int32)
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < cu_num:
num_kv_splits = (cu_num - 1 + grid_num) // grid_num
#[TODO] The relationship between L and block is to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
if L >= 4096:
num_kv_splits = (4 * cu_num - 1 + grid_num) // grid_num
attn_logits_v2 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
device="cuda",
)
if kv_group_num == 1:
# MHA
decode_attention_fwd_normal(
......@@ -1551,22 +1580,6 @@ def decode_attention_fwd(
page_size,
logit_cap,
)'''
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < cu_num:
num_kv_splits = (cu_num - 1 + grid_num) // grid_num
#[TODO] The relationship between L and block is to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
attn_logits_v2 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
device="cuda",
)
if best_config['kernel_kind'] == 'v1_2stages_tc':
attn_logits_v1 = torch.empty(
......@@ -1589,7 +1602,7 @@ def decode_attention_fwd(
logit_cap=logit_cap,
)
elif best_config['kernel_kind'] == 'v2_tc':
decode_attention_v2(
decode_attention_v1(
q,
k_buffer,
v_buffer,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment