Unverified Commit 8e6bdf85 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[triton] Support head_dim not 2^n in triton extend and decode attention (#1281)

parent 05bea688
...@@ -60,6 +60,7 @@ def _fwd_kernel_stage1( ...@@ -60,6 +60,7 @@ def _fwd_kernel_stage1(
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr, logit_cap: tl.constexpr,
Lk: tl.constexpr,
): ):
cur_batch = tl.program_id(0) cur_batch = tl.program_id(0)
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
...@@ -97,7 +98,7 @@ def _fwd_kernel_stage1( ...@@ -97,7 +98,7 @@ def _fwd_kernel_stage1(
) )
k = tl.load( k = tl.load(
K_Buffer + offs_buf_k, K_Buffer + offs_buf_k,
mask=offs_n_new[:, None] < cur_batch_end_index, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
other=0.0, other=0.0,
).to(REDUCE_TRITON_TYPE) ).to(REDUCE_TRITON_TYPE)
att_value = tl.sum(q[None, :] * k, 1) att_value = tl.sum(q[None, :] * k, 1)
...@@ -128,6 +129,7 @@ def _fwd_kernel_stage2( ...@@ -128,6 +129,7 @@ def _fwd_kernel_stage2(
kv_group_num: tl.constexpr, kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
Lv: tl.constexpr,
): ):
cur_batch = tl.program_id(0) cur_batch = tl.program_id(0)
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
...@@ -170,14 +172,16 @@ def _fwd_kernel_stage2( ...@@ -170,14 +172,16 @@ def _fwd_kernel_stage2(
old_scale = tl.exp(e_max - n_e_max) old_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max) p = tl.exp(qk - n_e_max)
e_sum = e_sum * old_scale + tl.sum(p, 0) e_sum = e_sum * old_scale + tl.sum(p, 0)
v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs) v = tl.load(
v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
)
acc = acc * old_scale + tl.sum(p[:, None] * v, 0) acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
e_max = n_e_max e_max = n_e_max
acc = acc / e_sum acc = acc / e_sum
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
out_ptrs = Out + off_o out_ptrs = Out + off_o
tl.store(out_ptrs, acc) tl.store(out_ptrs, acc, mask=(offs_d < Lv))
def _decode_att_m_fwd( def _decode_att_m_fwd(
...@@ -196,7 +200,7 @@ def _decode_att_m_fwd( ...@@ -196,7 +200,7 @@ def _decode_att_m_fwd(
# shape constraints # shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1] Lq, Lk = q.shape[-1], k_buffer.shape[-1]
assert Lq == Lk assert Lq == Lk
assert Lk in {16, 32, 64, 128, 256} assert Lk in {16, 32, 64, 96, 128, 256}
batch, head_num = B_req_idx.shape[0], q.shape[1] batch, head_num = B_req_idx.shape[0], q.shape[1]
...@@ -208,6 +212,8 @@ def _decode_att_m_fwd( ...@@ -208,6 +212,8 @@ def _decode_att_m_fwd(
else: else:
num_warps = 2 num_warps = 2
BLOCK_DMODEL = triton.next_power_of_2(Lk)
_fwd_kernel_stage1[grid]( _fwd_kernel_stage1[grid](
q, q,
k_buffer, k_buffer,
...@@ -224,11 +230,12 @@ def _decode_att_m_fwd( ...@@ -224,11 +230,12 @@ def _decode_att_m_fwd(
k_buffer.stride(1), k_buffer.stride(1),
att_out.stride(0), att_out.stride(0),
kv_group_num=kv_group_num, kv_group_num=kv_group_num,
BLOCK_DMODEL=Lk, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
logit_cap=logit_cap, logit_cap=logit_cap,
num_warps=num_warps, num_warps=num_warps,
num_stages=1, num_stages=1,
Lk=Lk,
) )
...@@ -248,6 +255,9 @@ def _decode_softmax_reducev_fwd( ...@@ -248,6 +255,9 @@ def _decode_softmax_reducev_fwd(
num_warps = 1 num_warps = 1
Lv = v_buffer.shape[-1]
BLOCK_DMODEL = triton.next_power_of_2(Lv)
_fwd_kernel_stage2[grid]( _fwd_kernel_stage2[grid](
logics, logics,
v_buffer, v_buffer,
...@@ -263,10 +273,11 @@ def _decode_softmax_reducev_fwd( ...@@ -263,10 +273,11 @@ def _decode_softmax_reducev_fwd(
o.stride(1), o.stride(1),
req_to_tokens.stride(0), req_to_tokens.stride(0),
kv_group_num=kv_group_num, kv_group_num=kv_group_num,
BLOCK_DMODEL=v_buffer.shape[-1], BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
num_warps=num_warps, num_warps=num_warps,
num_stages=3, num_stages=3,
Lv=Lv,
) )
...@@ -293,6 +304,7 @@ def _fwd_grouped_kernel_stage1( ...@@ -293,6 +304,7 @@ def _fwd_grouped_kernel_stage1(
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr, BLOCK_H: tl.constexpr,
logit_cap: tl.constexpr, logit_cap: tl.constexpr,
Lk: tl.constexpr,
): ):
cur_batch = tl.program_id(0) cur_batch = tl.program_id(0)
cur_kv_head = tl.program_id(1) cur_kv_head = tl.program_id(1)
...@@ -324,9 +336,9 @@ def _fwd_grouped_kernel_stage1( ...@@ -324,9 +336,9 @@ def _fwd_grouped_kernel_stage1(
block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
for start_mark in range(0, block_mask, 1): for start_mark in range(0, block_mask, 1):
q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to( q = tl.load(
REDUCE_TRITON_TYPE Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
) ).to(REDUCE_TRITON_TYPE)
offs_n_new = cur_batch_start_index + offs_n offs_n_new = cur_batch_start_index + offs_n
k_loc = tl.load( k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
...@@ -340,7 +352,7 @@ def _fwd_grouped_kernel_stage1( ...@@ -340,7 +352,7 @@ def _fwd_grouped_kernel_stage1(
) )
k = tl.load( k = tl.load(
K_Buffer + offs_buf_k, K_Buffer + offs_buf_k,
mask=offs_n_new[None, :] < cur_batch_end_index, mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
other=0.0, other=0.0,
).to(REDUCE_TRITON_TYPE) ).to(REDUCE_TRITON_TYPE)
qk = tl.dot(q, k) qk = tl.dot(q, k)
...@@ -395,6 +407,7 @@ def _fwd_grouped_kernel_stage2( ...@@ -395,6 +407,7 @@ def _fwd_grouped_kernel_stage2(
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr, BLOCK_H: tl.constexpr,
Lv: tl.constexpr,
): ):
cur_batch = tl.program_id(0) cur_batch = tl.program_id(0)
cur_kv_head = tl.program_id(1) cur_kv_head = tl.program_id(1)
...@@ -441,7 +454,9 @@ def _fwd_grouped_kernel_stage2( ...@@ -441,7 +454,9 @@ def _fwd_grouped_kernel_stage2(
old_scale = tl.exp(e_max - n_e_max) old_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None]) p = tl.exp(qk - n_e_max[:, None])
e_sum = e_sum * old_scale + tl.sum(p, 1) e_sum = e_sum * old_scale + tl.sum(p, 1)
v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs) v = tl.load(
v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
)
p = p.to(v.dtype) p = p.to(v.dtype)
acc = acc * old_scale[:, None] + tl.dot(p, v) acc = acc * old_scale[:, None] + tl.dot(p, v)
e_max = n_e_max e_max = n_e_max
...@@ -449,7 +464,7 @@ def _fwd_grouped_kernel_stage2( ...@@ -449,7 +464,7 @@ def _fwd_grouped_kernel_stage2(
acc = acc / e_sum[:, None] acc = acc / e_sum[:, None]
off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :] off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
out_ptrs = Out + off_o out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=mask_h[:, None]) tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))
def _decode_grouped_att_m_fwd( def _decode_grouped_att_m_fwd(
...@@ -468,13 +483,13 @@ def _decode_grouped_att_m_fwd( ...@@ -468,13 +483,13 @@ def _decode_grouped_att_m_fwd(
# shape constraints # shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1] Lq, Lk = q.shape[-1], k_buffer.shape[-1]
assert Lq == Lk assert Lq == Lk
assert Lk in {16, 32, 64, 128, 256, 576} assert Lk in {16, 32, 64, 96, 128, 256, 576}
if Lk == 576: if Lk == 576:
BLOCK_DMODEL = 512 BLOCK_DMODEL = 512
BLOCK_DPE = 64 BLOCK_DPE = 64
else: else:
BLOCK_DMODEL = Lk BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DPE = 0 BLOCK_DPE = 0
batch, head_num = B_req_idx.shape[0], q.shape[1] batch, head_num = B_req_idx.shape[0], q.shape[1]
...@@ -513,6 +528,7 @@ def _decode_grouped_att_m_fwd( ...@@ -513,6 +528,7 @@ def _decode_grouped_att_m_fwd(
logit_cap=logit_cap, logit_cap=logit_cap,
num_warps=num_warps, num_warps=num_warps,
num_stages=1, num_stages=1,
Lk=Lk,
) )
...@@ -533,6 +549,9 @@ def _decode_grouped_softmax_reducev_fwd( ...@@ -533,6 +549,9 @@ def _decode_grouped_softmax_reducev_fwd(
num_warps = 8 num_warps = 8
Lv = v_buffer.shape[-1]
BLOCK_DMODEL = triton.next_power_of_2(Lv)
_fwd_grouped_kernel_stage2[grid]( _fwd_grouped_kernel_stage2[grid](
logics, logics,
v_buffer, v_buffer,
...@@ -549,11 +568,12 @@ def _decode_grouped_softmax_reducev_fwd( ...@@ -549,11 +568,12 @@ def _decode_grouped_softmax_reducev_fwd(
req_to_tokens.stride(0), req_to_tokens.stride(0),
kv_group_num=kv_group_num, kv_group_num=kv_group_num,
q_head_num=head_num, q_head_num=head_num,
BLOCK_DMODEL=v_buffer.shape[-1], BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
BLOCK_H=BLOCK_H, BLOCK_H=BLOCK_H,
num_warps=num_warps, num_warps=num_warps,
num_stages=1, num_stages=1,
Lv=Lv,
) )
......
...@@ -15,7 +15,7 @@ limitations under the License. ...@@ -15,7 +15,7 @@ limitations under the License.
""" """
Memory-efficient attention for prefill. Memory-efficient attention for prefill.
It supporst page size = 1 and prefill with KV cache (i.e. extend). It supports page size = 1 and prefill with KV cache (i.e. extend).
""" """
import torch import torch
...@@ -67,6 +67,8 @@ def _fwd_kernel( ...@@ -67,6 +67,8 @@ def _fwd_kernel(
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr, logit_cap: tl.constexpr,
Lq: tl.constexpr,
Lv: tl.constexpr,
): ):
cur_seq = tl.program_id(0) cur_seq = tl.program_id(0)
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
...@@ -86,13 +88,18 @@ def _fwd_kernel( ...@@ -86,13 +88,18 @@ def _fwd_kernel(
offs_m = tl.arange(0, BLOCK_M) offs_m = tl.arange(0, BLOCK_M)
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
mask_d = offs_d < Lq
mask_dv = offs_dv < Lv
offs_q = ( offs_q = (
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs * stride_qbs
+ cur_head * stride_qh + cur_head * stride_qh
+ offs_d[None, :] + offs_d[None, :]
) )
q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0) q = tl.load(
Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
)
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
...@@ -125,7 +132,9 @@ def _fwd_kernel( ...@@ -125,7 +132,9 @@ def _fwd_kernel(
+ cur_kv_head * stride_buf_kh + cur_kv_head * stride_buf_kh
+ offs_d[:, None] + offs_d[:, None]
) )
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0) k = tl.load(
K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
)
qk = tl.dot(q.to(k.dtype), k) qk = tl.dot(q.to(k.dtype), k)
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
...@@ -157,7 +166,9 @@ def _fwd_kernel( ...@@ -157,7 +166,9 @@ def _fwd_kernel(
+ cur_kv_head * stride_buf_vh + cur_kv_head * stride_buf_vh
+ offs_dv[None, :] + offs_dv[None, :]
) )
v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0) v = tl.load(
V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
)
p = p.to(v.dtype) p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v) acc = acc * re_scale[:, None] + tl.dot(p, v)
...@@ -176,7 +187,9 @@ def _fwd_kernel( ...@@ -176,7 +187,9 @@ def _fwd_kernel(
+ cur_kv_head * stride_kh + cur_kv_head * stride_kh
+ offs_d[:, None] + offs_d[:, None]
) )
k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0) k = tl.load(
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
)
qk = tl.dot(q, k, out_dtype=tl.float32) qk = tl.dot(q, k, out_dtype=tl.float32)
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
...@@ -214,7 +227,9 @@ def _fwd_kernel( ...@@ -214,7 +227,9 @@ def _fwd_kernel(
+ cur_kv_head * stride_vh + cur_kv_head * stride_vh
+ offs_dv[None, :] + offs_dv[None, :]
) )
v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0) v = tl.load(
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
)
p = p.to(v.dtype) p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v) acc = acc * re_scale[:, None] + tl.dot(p, v)
...@@ -226,7 +241,9 @@ def _fwd_kernel( ...@@ -226,7 +241,9 @@ def _fwd_kernel(
+ cur_head * stride_oh + cur_head * stride_oh
+ offs_dv[None, :] + offs_dv[None, :]
) )
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None]) tl.store(
O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
)
def extend_attention_fwd( def extend_attention_fwd(
...@@ -261,16 +278,18 @@ def extend_attention_fwd( ...@@ -261,16 +278,18 @@ def extend_attention_fwd(
) )
assert Lq == Lk and Lv == Lo assert Lq == Lk and Lv == Lo
assert Lq in {16, 32, 64, 128, 256, 576}
assert Lv in {16, 32, 64, 128, 256, 512} # TODO: is the assertion necessary?
assert Lq in {16, 32, 64, 96, 128, 256, 576}
assert Lv in {16, 32, 64, 96, 128, 256, 512}
if Lq == 576: if Lq == 576:
BLOCK_DMODEL = 512 BLOCK_DMODEL = 512
BLOCK_DPE = 64 BLOCK_DPE = 64
else: else:
BLOCK_DMODEL = Lq BLOCK_DMODEL = triton.next_power_of_2(Lq)
BLOCK_DPE = 0 BLOCK_DPE = 0
BLOCK_DV = Lv BLOCK_DV = triton.next_power_of_2(Lv)
if CUDA_CAPABILITY[0] >= 9: if CUDA_CAPABILITY[0] >= 9:
if Lq <= 256: if Lq <= 256:
...@@ -330,6 +349,8 @@ def extend_attention_fwd( ...@@ -330,6 +349,8 @@ def extend_attention_fwd(
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
logit_cap=logit_cap, logit_cap=logit_cap,
Lq=Lq,
Lv=Lv,
) )
...@@ -373,10 +394,7 @@ def redundant_attention( ...@@ -373,10 +394,7 @@ def redundant_attention(
pt += cur_seq_len_extend pt += cur_seq_len_extend
def test(): def test_once(B, N_CTX, H_Q, H_KV, D):
torch.manual_seed(0)
B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128
dtype = torch.float16 dtype = torch.float16
b_seq_len_prefix = torch.randint( b_seq_len_prefix = torch.randint(
...@@ -473,4 +491,5 @@ def test(): ...@@ -473,4 +491,5 @@ def test():
if __name__ == "__main__": if __name__ == "__main__":
test() test_once(19, 12331, 12, 4, 128)
test_once(19, 12331, 12, 4, 96)
...@@ -48,6 +48,7 @@ def _fwd_kernel( ...@@ -48,6 +48,7 @@ def _fwd_kernel(
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
Lk: tl.constexpr,
): ):
cur_batch = tl.program_id(0) cur_batch = tl.program_id(0)
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
...@@ -72,7 +73,11 @@ def _fwd_kernel( ...@@ -72,7 +73,11 @@ def _fwd_kernel(
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) mask_d = offs_d < Lk
q = tl.load(
Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0
)
k_ptrs = K + off_k k_ptrs = K + off_k
v_ptrs = V + off_v v_ptrs = V + off_v
...@@ -89,7 +94,7 @@ def _fwd_kernel( ...@@ -89,7 +94,7 @@ def _fwd_kernel(
# -- compute qk ---- # -- compute qk ----
k = tl.load( k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
other=0.0, other=0.0,
) )
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
...@@ -118,7 +123,7 @@ def _fwd_kernel( ...@@ -118,7 +123,7 @@ def _fwd_kernel(
# update acc # update acc
v = tl.load( v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
other=0.0, other=0.0,
) )
...@@ -134,7 +139,9 @@ def _fwd_kernel( ...@@ -134,7 +139,9 @@ def _fwd_kernel(
+ offs_d[None, :] + offs_d[None, :]
) )
out_ptrs = Out + off_o out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) tl.store(
out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
)
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
...@@ -145,7 +152,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): ...@@ -145,7 +152,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128, 256} assert Lk in {16, 32, 64, 96, 128, 256}
sm_scale = 1.0 / (Lq**0.5) sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1] batch, head = b_seq_len.shape[0], q.shape[1]
...@@ -172,8 +179,9 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): ...@@ -172,8 +179,9 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
o.stride(1), o.stride(1),
kv_group_num=kv_group_num, kv_group_num=kv_group_num,
BLOCK_M=BLOCK, BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk, BLOCK_DMODEL=triton.next_power_of_2(Lk),
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
num_warps=num_warps, num_warps=num_warps,
num_stages=1, num_stages=1,
Lk=Lk,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment