Commit 6f49c1ed authored by zhuwenwen
Browse files

back to mla v2

parent cf28e5a4
......@@ -213,4 +213,3 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
with open(file_name, 'w') as file:
json.dump(config_info, file, indent=1)
#**************save config**************#
\ No newline at end of file
......@@ -37,10 +37,7 @@ import triton.language as tl
from vllm.platforms import current_platform
is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"1"
os.environ["TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"]="0"
os.environ["TRITON_DEFAULT_ENABLE_NUM_VGPRS512"] = "1"
os.environ["MLIR_ENABLE_DUMP"] = "0"
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0"
logger = logging.getLogger(__name__)
......@@ -760,12 +757,6 @@ def decode_attention_v1(
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=2),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
......@@ -831,15 +822,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
# if BLOCK_DPE > 0:
# offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
# mask_dpe = offs_dpe < Lk
# off_qpe = (
# cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
# )
# qpe = tl.load(
# Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
# )
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
......@@ -868,19 +859,11 @@ def _decode_v2_kernel_stage1_use_tc(
k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0)
qk += tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
......@@ -1018,6 +1001,7 @@ def _decode_v2_stage1_use_tc(
@triton.autotune(
configs=[
triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=2, num_stages=1),
triton.Config({}, num_warps=4, num_stages=1),
......@@ -1160,20 +1144,11 @@ def decode_attentionv2_fwd(
):
assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2]
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < cu_num:
num_kv_splits = (cu_num - 1 + grid_num) // grid_num
#[TODO] The relationship between L and block is to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
if L >= 4096:
num_kv_splits = (4 * cu_num - 1 + grid_num) // grid_num
if grid_num * num_kv_splits < 128:
num_kv_splits = (127 + grid_num) // grid_num
attn_logits_v1 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
......@@ -1263,4 +1238,3 @@ def decode_attentionv1_fwd(
logit_cap,
)
return v1_tc_stage1_best_config, v1_tc_stage2_best_config
\ No newline at end of file
......@@ -39,10 +39,7 @@ from vllm import envs
# from ..backends.triton_config import KERNLE_KINDS
is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"1"
os.environ["TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"]="0"
os.environ["TRITON_DEFAULT_ENABLE_NUM_VGPRS512"] = "1"
os.environ["MLIR_ENABLE_DUMP"] = "0"
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0"
logger = logging.getLogger(__name__)
......@@ -1071,12 +1068,6 @@ def decode_attention_v1(
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=2),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=2),
# triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
......@@ -1142,15 +1133,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
# if BLOCK_DPE > 0:
# offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
# mask_dpe = offs_dpe < Lk
# off_qpe = (
# cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
# )
# qpe = tl.load(
# Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
# )
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
......@@ -1180,21 +1171,11 @@ def _decode_v2_kernel_stage1_use_tc(
k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0)
qk += tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
......@@ -1346,6 +1327,7 @@ def _decode_v2_stage1_use_tc(
# @triton.autotune(
# configs=[
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=2, num_stages=1),
# triton.Config({}, num_warps=4, num_stages=1),
# triton.Config({}, num_warps=8, num_stages=1),
......@@ -1510,26 +1492,7 @@ def decode_attention_fwd(
):
assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2]
b_start_loc = torch.arange(0, req_to_token.shape[0]*req_to_token.shape[1]*page_size, req_to_token.shape[0]*req_to_token.shape[1]* page_size // q.shape[0], device="cuda").to(torch.int32)
current_device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(current_device)
cu_num = props.multi_processor_count
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < cu_num:
num_kv_splits = (cu_num - 1 + grid_num) // grid_num
#[TODO] The relationship between L and block is to be analyzed
if L >= 2048:
num_kv_splits = (2 * cu_num - 1 + grid_num) // grid_num
if L >= 4096:
num_kv_splits = (4 * cu_num - 1 + grid_num) // grid_num
attn_logits_v2 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
device="cuda",
)
b_start_loc = torch.arange(0, req_to_token.shape[0]*req_to_token.shape[1], req_to_token.shape[0]*req_to_token.shape[1] // q.shape[0], device="cuda").to(torch.int32)
if kv_group_num == 1:
# MHA
decode_attention_fwd_normal(
......@@ -1580,6 +1543,16 @@ def decode_attention_fwd(
page_size,
logit_cap,
)'''
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < 128:
num_kv_splits = (127 + grid_num) // grid_num
attn_logits_v2 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
device="cuda",
)
if best_config['kernel_kind'] == 'v1_2stages_tc':
attn_logits_v1 = torch.empty(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment