Commit 6f49c1ed authored by zhuwenwen
Browse files

back to mla v2

parent cf28e5a4
...@@ -213,4 +213,3 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ...@@ -213,4 +213,3 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
with open(file_name, 'w') as file: with open(file_name, 'w') as file:
json.dump(config_info, file, indent=1) json.dump(config_info, file, indent=1)
#**************save config**************# #**************save config**************#
...@@ -37,10 +37,7 @@ import triton.language as tl ...@@ -37,10 +37,7 @@ import triton.language as tl
from vllm.platforms import current_platform from vllm.platforms import current_platform
is_hip_ = current_platform.is_rocm() is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"1" os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0"
os.environ["TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"]="0"
os.environ["TRITON_DEFAULT_ENABLE_NUM_VGPRS512"] = "1"
os.environ["MLIR_ENABLE_DUMP"] = "0"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -760,12 +757,6 @@ def decode_attention_v1( ...@@ -760,12 +757,6 @@ def decode_attention_v1(
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1), triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1), triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1), triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=2),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=2, num_stages=1), triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=4, num_stages=1), triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=2, num_stages=1), triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
...@@ -831,15 +822,15 @@ def _decode_v2_kernel_stage1_use_tc( ...@@ -831,15 +822,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] # offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) # q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
# if BLOCK_DPE > 0: if BLOCK_DPE > 0:
# offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
# mask_dpe = offs_dpe < Lk mask_dpe = offs_dpe < Lk
# off_qpe = ( off_qpe = (
# cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
# ) )
# qpe = tl.load( qpe = tl.load(
# Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
# ) )
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id split_kv_start = kv_len_per_split * split_kv_id
...@@ -868,19 +859,11 @@ def _decode_v2_kernel_stage1_use_tc( ...@@ -868,19 +859,11 @@ def _decode_v2_kernel_stage1_use_tc(
k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0) k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0)
qk += tl.dot(q, k.to(q.dtype)) qk += tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
offs_buf_kpe = ( offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh + cur_kv_head * stride_buf_kh
+ offs_dpe[:, None] + offs_dpe[:, None]
) )
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kpe = tl.load( kpe = tl.load(
K_Buffer + offs_buf_kpe, K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
...@@ -1018,6 +1001,7 @@ def _decode_v2_stage1_use_tc( ...@@ -1018,6 +1001,7 @@ def _decode_v2_stage1_use_tc(
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=1, num_stages=1), triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=2, num_stages=1), triton.Config({}, num_warps=2, num_stages=1),
triton.Config({}, num_warps=4, num_stages=1), triton.Config({}, num_warps=4, num_stages=1),
...@@ -1160,20 +1144,11 @@ def decode_attentionv2_fwd( ...@@ -1160,20 +1144,11 @@ def decode_attentionv2_fwd(
): ):
assert num_kv_splits == attn_logits.shape[2] assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2] kv_group_num = q.shape[1] // v_buffer.shape[-2]
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16) num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0] grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < cu_num: if grid_num * num_kv_splits < 128:
num_kv_splits = (cu_num - 1 + grid_num) // grid_num num_kv_splits = (127 + grid_num) // grid_num
#[TODO] The relationship between L and block is to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
if L >= 4096:
num_kv_splits = (4 * cu_num - 1 + grid_num) // grid_num
attn_logits_v1 = torch.empty( attn_logits_v1 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1), (q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32, dtype=torch.float32,
...@@ -1263,4 +1238,3 @@ def decode_attentionv1_fwd( ...@@ -1263,4 +1238,3 @@ def decode_attentionv1_fwd(
logit_cap, logit_cap,
) )
return v1_tc_stage1_best_config, v1_tc_stage2_best_config return v1_tc_stage1_best_config, v1_tc_stage2_best_config
\ No newline at end of file
...@@ -39,10 +39,7 @@ from vllm import envs ...@@ -39,10 +39,7 @@ from vllm import envs
# from ..backends.triton_config import KERNLE_KINDS # from ..backends.triton_config import KERNLE_KINDS
is_hip_ = current_platform.is_rocm() is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"1" os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0"
os.environ["TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"]="0"
os.environ["TRITON_DEFAULT_ENABLE_NUM_VGPRS512"] = "1"
os.environ["MLIR_ENABLE_DUMP"] = "0"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -1071,12 +1068,6 @@ def decode_attention_v1( ...@@ -1071,12 +1068,6 @@ def decode_attention_v1(
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1), # triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1), # triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=2),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=2),
# triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=4, num_stages=1), # triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
...@@ -1142,15 +1133,15 @@ def _decode_v2_kernel_stage1_use_tc( ...@@ -1142,15 +1133,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] # offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) # q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
# if BLOCK_DPE > 0: if BLOCK_DPE > 0:
# offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
# mask_dpe = offs_dpe < Lk mask_dpe = offs_dpe < Lk
# off_qpe = ( off_qpe = (
# cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
# ) )
# qpe = tl.load( qpe = tl.load(
# Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
# ) )
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id split_kv_start = kv_len_per_split * split_kv_id
...@@ -1180,21 +1171,11 @@ def _decode_v2_kernel_stage1_use_tc( ...@@ -1180,21 +1171,11 @@ def _decode_v2_kernel_stage1_use_tc(
k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0) k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0)
qk += tl.dot(q, k.to(q.dtype)) qk += tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
offs_buf_kpe = ( offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh + cur_kv_head * stride_buf_kh
+ offs_dpe[:, None] + offs_dpe[:, None]
) )
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kpe = tl.load( kpe = tl.load(
K_Buffer + offs_buf_kpe, K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
...@@ -1346,6 +1327,7 @@ def _decode_v2_stage1_use_tc( ...@@ -1346,6 +1327,7 @@ def _decode_v2_stage1_use_tc(
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
# triton.Config({}, num_warps=1, num_stages=1), # triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=2, num_stages=1), # triton.Config({}, num_warps=2, num_stages=1),
# triton.Config({}, num_warps=4, num_stages=1), # triton.Config({}, num_warps=4, num_stages=1),
# triton.Config({}, num_warps=8, num_stages=1), # triton.Config({}, num_warps=8, num_stages=1),
...@@ -1510,26 +1492,7 @@ def decode_attention_fwd( ...@@ -1510,26 +1492,7 @@ def decode_attention_fwd(
): ):
assert num_kv_splits == attn_logits.shape[2] assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2] kv_group_num = q.shape[1] // v_buffer.shape[-2]
b_start_loc = torch.arange(0, req_to_token.shape[0]*req_to_token.shape[1]*page_size, req_to_token.shape[0]*req_to_token.shape[1]* page_size // q.shape[0], device="cuda").to(torch.int32) b_start_loc = torch.arange(0, req_to_token.shape[0]*req_to_token.shape[1], req_to_token.shape[0]*req_to_token.shape[1] // q.shape[0], device="cuda").to(torch.int32)
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < cu_num:
num_kv_splits = (cu_num - 1 + grid_num) // grid_num
#[TODO] The relationship between L and block is to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
if L >= 4096:
num_kv_splits = (4 * cu_num - 1 + grid_num) // grid_num
attn_logits_v2 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
device="cuda",
)
if kv_group_num == 1: if kv_group_num == 1:
# MHA # MHA
decode_attention_fwd_normal( decode_attention_fwd_normal(
...@@ -1580,6 +1543,16 @@ def decode_attention_fwd( ...@@ -1580,6 +1543,16 @@ def decode_attention_fwd(
page_size, page_size,
logit_cap, logit_cap,
)''' )'''
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < 128:
num_kv_splits = (127 + grid_num) // grid_num
attn_logits_v2 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
device="cuda",
)
if best_config['kernel_kind'] == 'v1_2stages_tc': if best_config['kernel_kind'] == 'v1_2stages_tc':
attn_logits_v1 = torch.empty( attn_logits_v1 = torch.empty(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment