Commit 8b5a09f6 authored by zhuwenwen's avatar zhuwenwen
Browse files

update triton_decode_attention.py

parent f6044f1a
...@@ -40,6 +40,7 @@ is_hip_ = current_platform.is_rocm() ...@@ -40,6 +40,7 @@ is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"1" os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"1"
os.environ["TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"]="0" os.environ["TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"]="0"
os.environ["TRITON_DEFAULT_ENABLE_NUM_VGPRS512"] = "1" os.environ["TRITON_DEFAULT_ENABLE_NUM_VGPRS512"] = "1"
os.environ["MLIR_ENABLE_DUMP"] = "0"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -830,15 +831,15 @@ def _decode_v2_kernel_stage1_use_tc( ...@@ -830,15 +831,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] # offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) # q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
if BLOCK_DPE > 0: # if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) # offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk # mask_dpe = offs_dpe < Lk
off_qpe = ( # off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] # cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
) # )
qpe = tl.load( # qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 # Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
) # )
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id split_kv_start = kv_len_per_split * split_kv_id
...@@ -867,11 +868,19 @@ def _decode_v2_kernel_stage1_use_tc( ...@@ -867,11 +868,19 @@ def _decode_v2_kernel_stage1_use_tc(
k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0) k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0)
qk += tl.dot(q, k.to(q.dtype)) qk += tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
offs_buf_kpe = ( offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh + cur_kv_head * stride_buf_kh
+ offs_dpe[:, None] + offs_dpe[:, None]
) )
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kpe = tl.load( kpe = tl.load(
K_Buffer + offs_buf_kpe, K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
...@@ -1009,7 +1018,6 @@ def _decode_v2_stage1_use_tc( ...@@ -1009,7 +1018,6 @@ def _decode_v2_stage1_use_tc(
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=1, num_stages=1), triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=2, num_stages=1), triton.Config({}, num_warps=2, num_stages=1),
triton.Config({}, num_warps=4, num_stages=1), triton.Config({}, num_warps=4, num_stages=1),
...@@ -1163,6 +1171,9 @@ def decode_attentionv2_fwd( ...@@ -1163,6 +1171,9 @@ def decode_attentionv2_fwd(
#[TODO] The relationship between L and block is to be analyzed #[TODO] The relationship between L and block is to be analyzed
if L >= 2048: if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
if L >= 4096:
num_kv_splits = (4 * cu_num - 1 + grid_num) // grid_num
attn_logits_v1 = torch.empty( attn_logits_v1 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1), (q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32, dtype=torch.float32,
......
...@@ -42,6 +42,7 @@ is_hip_ = current_platform.is_rocm() ...@@ -42,6 +42,7 @@ is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"1" os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"1"
os.environ["TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"]="0" os.environ["TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"]="0"
os.environ["TRITON_DEFAULT_ENABLE_NUM_VGPRS512"] = "1" os.environ["TRITON_DEFAULT_ENABLE_NUM_VGPRS512"] = "1"
os.environ["MLIR_ENABLE_DUMP"] = "0"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -1141,15 +1142,15 @@ def _decode_v2_kernel_stage1_use_tc( ...@@ -1141,15 +1142,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] # offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) # q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
if BLOCK_DPE > 0: # if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) # offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk # mask_dpe = offs_dpe < Lk
off_qpe = ( # off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] # cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
) # )
qpe = tl.load( # qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 # Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
) # )
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id split_kv_start = kv_len_per_split * split_kv_id
...@@ -1179,11 +1180,21 @@ def _decode_v2_kernel_stage1_use_tc( ...@@ -1179,11 +1180,21 @@ def _decode_v2_kernel_stage1_use_tc(
k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0) k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0)
qk += tl.dot(q, k.to(q.dtype)) qk += tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
offs_buf_kpe = ( offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh + cur_kv_head * stride_buf_kh
+ offs_dpe[:, None] + offs_dpe[:, None]
) )
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kpe = tl.load( kpe = tl.load(
K_Buffer + offs_buf_kpe, K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
...@@ -1335,7 +1346,6 @@ def _decode_v2_stage1_use_tc( ...@@ -1335,7 +1346,6 @@ def _decode_v2_stage1_use_tc(
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
# triton.Config({}, num_warps=1, num_stages=1), # triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=2, num_stages=1), # triton.Config({}, num_warps=2, num_stages=1),
# triton.Config({}, num_warps=4, num_stages=1), # triton.Config({}, num_warps=4, num_stages=1),
# triton.Config({}, num_warps=8, num_stages=1), # triton.Config({}, num_warps=8, num_stages=1),
...@@ -1500,7 +1510,26 @@ def decode_attention_fwd( ...@@ -1500,7 +1510,26 @@ def decode_attention_fwd(
): ):
assert num_kv_splits == attn_logits.shape[2] assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2] kv_group_num = q.shape[1] // v_buffer.shape[-2]
b_start_loc = torch.arange(0, req_to_token.shape[0]*req_to_token.shape[1], req_to_token.shape[0]*req_to_token.shape[1] // q.shape[0], device="cuda").to(torch.int32) b_start_loc = torch.arange(0, req_to_token.shape[0]*req_to_token.shape[1]*page_size, req_to_token.shape[0]*req_to_token.shape[1]* page_size // q.shape[0], device="cuda").to(torch.int32)
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < cu_num:
num_kv_splits = (cu_num - 1 + grid_num) // grid_num
#[TODO] The relationship between L and block is to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
if L >= 4096:
num_kv_splits = (4 * cu_num - 1 + grid_num) // grid_num
attn_logits_v2 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
device="cuda",
)
if kv_group_num == 1: if kv_group_num == 1:
# MHA # MHA
decode_attention_fwd_normal( decode_attention_fwd_normal(
...@@ -1551,22 +1580,6 @@ def decode_attention_fwd( ...@@ -1551,22 +1580,6 @@ def decode_attention_fwd(
page_size, page_size,
logit_cap, logit_cap,
)''' )'''
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < cu_num:
num_kv_splits = (cu_num - 1 + grid_num) // grid_num
#[TODO] The relationship between L and block is to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
attn_logits_v2 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
device="cuda",
)
if best_config['kernel_kind'] == 'v1_2stages_tc': if best_config['kernel_kind'] == 'v1_2stages_tc':
attn_logits_v1 = torch.empty( attn_logits_v1 = torch.empty(
...@@ -1589,7 +1602,7 @@ def decode_attention_fwd( ...@@ -1589,7 +1602,7 @@ def decode_attention_fwd(
logit_cap=logit_cap, logit_cap=logit_cap,
) )
elif best_config['kernel_kind'] == 'v2_tc': elif best_config['kernel_kind'] == 'v2_tc':
decode_attention_v2( decode_attention_v1(
q, q,
k_buffer, k_buffer,
v_buffer, v_buffer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment