Unverified Commit c5f86501 authored by Ke Bao, committed by GitHub

Fix grid size in Triton decoding kernel (#2134)

parent d98fa1e9
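
This commit replaces the third launch-grid dimension, which previously scaled with the longest sequence in the batch (one program per BLOCK_N chunk), with a fixed split-K partition of each sequence's KV range. Below is a minimal host-side sketch of that difference, reusing the names from _decode_att_m_fwd in the diff; the concrete sizes are illustrative only, not from the commit.

import triton

batch, head_num = 4, 32     # illustrative values
max_len_in_batch = 4096
BLOCK = 32                  # BLOCK_N step used inside the kernel
SPLIT_K = 8                 # fixed number of splits per sequence

# Before: the grid grows with the longest sequence in the batch.
grid_old = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))  # (4, 32, 128)

# After: the grid is bounded; each program walks its own split in BLOCK-sized steps.
grid_new = (batch, head_num, SPLIT_K)  # (4, 32, 8)

print(grid_old, grid_new)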
@@ -50,12 +50,13 @@ def _fwd_kernel_stage1(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
-    start_n = tl.program_id(2)
+    split_k_id = tl.program_id(2)
     reduce_dtype = Att_Out.dtype.element_ty
     cur_kv_head = cur_head // kv_group_num
@@ -65,22 +66,18 @@ def _fwd_kernel_stage1(
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
-    cur_batch_start_index = 0
-    cur_batch_end_index = cur_batch_seq_len
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
+    q = tl.load(Q + off_q).to(reduce_dtype)
-    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
-    block_stard_index = start_n * BLOCK_N
-    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
+    split_k_start = kv_len_per_split * split_k_id
+    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
-    for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark).to(reduce_dtype)
-        offs_n_new = cur_batch_start_index + offs_n
+    for start_n in range(split_k_start, split_k_end, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
         k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
-            mask=offs_n_new < cur_batch_end_index,
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            mask=offs_n < split_k_end,
             other=0,
         )
         offs_buf_k = (
@@ -90,7 +87,7 @@ def _fwd_kernel_stage1(
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
+            mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk),
             other=0.0,
         ).to(reduce_dtype)
         att_value = tl.sum(q[None, :] * k, 1)
@@ -100,7 +97,7 @@ def _fwd_kernel_stage1(
             att_value = logit_cap * tanh(att_value / logit_cap)
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
-        tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
+        tl.store(Att_Out + off_o, att_value, mask=offs_n < split_k_end)
 @triton.jit
@@ -189,11 +186,12 @@ def _decode_att_m_fwd(
     logit_cap,
 ):
     BLOCK = 32
+    SPLIT_K = 8
     Lk = k_buffer.shape[-1]
     batch, head_num = B_req_idx.shape[0], q.shape[1]
-    grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))
+    grid = (batch, head_num, SPLIT_K)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
     if kv_group_num == 1:
@@ -221,6 +219,7 @@ def _decode_att_m_fwd(
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
+        SPLIT_K=SPLIT_K,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
@@ -292,13 +291,14 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head_id = tl.program_id(1)
     cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
-    start_n = tl.program_id(2)
+    split_k_id = tl.program_id(2)
     reduce_dtype = Att_Out.dtype.element_ty
@@ -315,30 +315,27 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
-    cur_batch_start_index = 0
-    cur_batch_end_index = cur_batch_seq_len
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
+    q = tl.load(
+        Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0
+    ).to(reduce_dtype)
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
+        qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype)
-    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
+    split_k_start = kv_len_per_split * split_k_id
+    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
-    block_stard_index = start_n * BLOCK_N
-    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
-    for start_mark in range(0, block_mask, 1):
-        q = tl.load(
-            Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
-        ).to(reduce_dtype)
-        offs_n_new = cur_batch_start_index + offs_n
+    for start_n in range(split_k_start, split_k_end, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
         k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
-            mask=offs_n_new < cur_batch_end_index,
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            mask=offs_n < split_k_end,
             other=0,
         )
         offs_buf_k = (
@@ -348,14 +345,11 @@ def _fwd_grouped_kernel_stage1(
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
+            mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
             other=0.0,
         ).to(reduce_dtype)
         qk = tl.dot(q, k)
         if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
-                reduce_dtype
-            )
             offs_buf_kpe = (
                 k_loc[None, :] * stride_buf_kbs
                 + cur_kv_head * stride_buf_kh
@@ -363,7 +357,7 @@ def _fwd_grouped_kernel_stage1(
             )
             kpe = tl.load(
                 K_Buffer + offs_buf_kpe,
-                mask=offs_n_new[None, :] < cur_batch_end_index,
+                mask=offs_n[None, :] < split_k_end,
                 other=0.0,
             ).to(reduce_dtype)
             qk += tl.dot(qpe, kpe)
@@ -379,7 +373,7 @@ def _fwd_grouped_kernel_stage1(
         tl.store(
             Att_Out + offs_o,
             qk,
-            mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index),
+            mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
         )
@@ -497,10 +491,11 @@ def _decode_grouped_att_m_fwd(
     kv_group_num = q.shape[1] // k_buffer.shape[1]
     BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
+    SPLIT_K = 8
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        triton.cdiv(max_len_in_batch, BLOCK),
+        SPLIT_K,
     )
     num_warps = 4
@@ -532,6 +527,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
+        SPLIT_K=SPLIT_K,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
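
For reference, a plain-Python sketch of the per-split range arithmetic that both updated kernels perform with tl.cdiv and tl.minimum; seq_len stands in for cur_batch_seq_len and the numbers are illustrative only.

def split_ranges(seq_len: int, split_k: int = 8):
    # Mirrors kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
    kv_len_per_split = -(-seq_len // split_k)  # ceiling division
    ranges = []
    for split_k_id in range(split_k):
        split_k_start = kv_len_per_split * split_k_id
        split_k_end = min(split_k_start + kv_len_per_split, seq_len)
        if split_k_start < split_k_end:
            ranges.append((split_k_start, split_k_end))
    return ranges

# Every token index below seq_len falls into exactly one split, and the last
# split is clipped by the min(); this is why the loads and the store in the
# diff are masked with offs_n < split_k_end instead of the old
# cur_batch_end_index check.
print(split_ranges(100))  # [(0, 13), (13, 26), (26, 39), ..., (91, 100)]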