Unverified Commit df191254 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Optimize MLA/GQA/MQA Triton decoding (#1138)


Co-authored-by: default avatarYineng Zhang <me@zhyncs.com>
parent b997a18d
......@@ -58,7 +58,6 @@ def _fwd_kernel_stage1(
att_stride_h,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
):
......@@ -78,10 +77,6 @@ def _fwd_kernel_stage1(
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
block_stard_index = start_n * BLOCK_N
......@@ -106,19 +101,6 @@ def _fwd_kernel_stage1(
other=0.0,
).to(REDUCE_TRITON_TYPE)
att_value = tl.sum(q[None, :] * k, 1)
if BLOCK_DPE > 0:
qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
offs_buf_kpe = (
k_loc[:, None] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[None, :]
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=offs_n_new[:, None] < cur_batch_end_index,
other=0.0,
).to(REDUCE_TRITON_TYPE)
att_value += tl.sum(qpe[None, :] * kpe, 1)
att_value *= sm_scale
if logit_cap > 0:
......@@ -214,14 +196,7 @@ def _decode_att_m_fwd(
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128, 256, 576}
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
else:
BLOCK_DMODEL = Lk
BLOCK_DPE = 0
assert Lk in {16, 32, 64, 128, 256}
batch, head_num = B_req_idx.shape[0], q.shape[1]
......@@ -249,8 +224,7 @@ def _decode_att_m_fwd(
k_buffer.stride(1),
att_out.stride(0),
kv_group_num=kv_group_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
logit_cap=logit_cap,
num_warps=num_warps,
......@@ -296,6 +270,293 @@ def _decode_softmax_reducev_fwd(
)
@triton.jit
def _fwd_grouped_kernel_stage1(
Q,
K_Buffer,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
att_stride_h,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr,
logit_cap: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_kv_head = tl.program_id(1)
start_n = tl.program_id(2)
cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_batch_start_index = 0
cur_batch_end_index = cur_batch_seq_len
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
block_stard_index = start_n * BLOCK_N
block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
for start_mark in range(0, block_mask, 1):
q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
REDUCE_TRITON_TYPE
)
offs_n_new = cur_batch_start_index + offs_n
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
mask=offs_n_new < cur_batch_end_index,
other=0,
)
offs_buf_k = (
k_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=offs_n_new[None, :] < cur_batch_end_index,
other=0.0,
).to(REDUCE_TRITON_TYPE)
qk = tl.dot(q, k)
if BLOCK_DPE > 0:
qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
REDUCE_TRITON_TYPE
)
offs_buf_kpe = (
k_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=offs_n_new[None, :] < cur_batch_end_index,
other=0.0,
).to(REDUCE_TRITON_TYPE)
qk += tl.dot(qpe, kpe)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
offs_o = cur_head[:, None] * att_stride_h + (
cur_batch_in_all_start_index + offs_n[None, :]
)
tl.store(
Att_Out + offs_o,
qk,
mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index),
)
@triton.jit
def _fwd_grouped_kernel_stage2(
Logics,
V_Buffer,
Out,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
stride_logic_h,
stride_buf_vbs,
stride_buf_vh,
stride_obs,
stride_oh,
stride_req_to_token_b,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_kv_head = tl.program_id(1)
cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
mask_h = mask_h & (cur_head < q_head_num)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
v_ptrs = V_Buffer + offs_buf_v
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
v_index = tl.load(
Req_to_tokens
+ cur_batch_req_idx * stride_req_to_token_b
+ (start_n + offs_n),
mask=(start_n + offs_n) < cur_batch_seq_len,
other=0,
)
offs_qk = cur_head[:, None] * stride_logic_h + (
cur_batch_start_loc + start_n + offs_n[None, :]
)
qk = tl.load(
Logics + offs_qk,
mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
other=float("-inf"),
)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
old_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
e_sum = e_sum * old_scale + tl.sum(p, 1)
v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
p = p.to(v.dtype)
acc = acc * old_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
acc = acc / e_sum[:, None]
off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=mask_h[:, None])
def _decode_grouped_att_m_fwd(
q,
k_buffer,
att_out,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
max_len_in_batch,
sm_scale,
logit_cap,
):
BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128, 256, 576}
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
else:
BLOCK_DMODEL = Lk
BLOCK_DPE = 0
batch, head_num = B_req_idx.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[1]
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
triton.cdiv(max_len_in_batch, BLOCK),
)
num_warps = 4
_fwd_grouped_kernel_stage1[grid](
q,
k_buffer,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
att_out.stride(0),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_N=BLOCK,
BLOCK_H=BLOCK_H,
logit_cap=logit_cap,
num_warps=num_warps,
num_stages=1,
)
def _decode_grouped_softmax_reducev_fwd(
logics,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
):
BLOCK = 128
batch, head_num = b_seq_len.shape[0], logics.shape[0]
kv_group_num = logics.shape[0] // v_buffer.shape[1]
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
num_warps = 8
_fwd_grouped_kernel_stage2[grid](
logics,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
logics.stride(0),
v_buffer.stride(0),
v_buffer.stride(1),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=v_buffer.shape[-1],
BLOCK_N=BLOCK,
BLOCK_H=BLOCK_H,
num_warps=num_warps,
num_stages=1,
)
def decode_attention_fwd(
q,
k_buffer,
......@@ -316,24 +577,51 @@ def decode_attention_fwd(
(q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
)
_decode_att_m_fwd(
q,
k_buffer,
att_m,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
max_len_in_batch,
sm_scale,
logit_cap,
)
_decode_softmax_reducev_fwd(
att_m,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
)
kv_group_num = q.shape[1] // v_buffer.shape[1]
if kv_group_num == 1:
# MHA
_decode_att_m_fwd(
q,
k_buffer,
att_m,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
max_len_in_batch,
sm_scale,
logit_cap,
)
_decode_softmax_reducev_fwd(
att_m,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
)
else:
# GQA/MQA/MLA
_decode_grouped_att_m_fwd(
q,
k_buffer,
att_m,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
max_len_in_batch,
sm_scale,
logit_cap,
)
_decode_grouped_softmax_reducev_fwd(
att_m,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment