Unverified commit ddcf9fe3, authored by Ke Bao, committed by GitHub
Browse files

Optimize triton attention custom mask (#3731)

parent 6252ade9
...@@ -74,6 +74,7 @@ def _fwd_kernel( ...@@ -74,6 +74,7 @@ def _fwd_kernel(
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
USE_CUSTOM_MASK: tl.constexpr, USE_CUSTOM_MASK: tl.constexpr,
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
STORE_TRANSPOSE: tl.constexpr, STORE_TRANSPOSE: tl.constexpr,
): ):
cur_seq = tl.program_id(0) cur_seq = tl.program_id(0)
...@@ -160,7 +161,7 @@ def _fwd_kernel( ...@@ -160,7 +161,7 @@ def _fwd_kernel(
if logit_cap > 0: if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap) qk = logit_cap * tanh(qk / logit_cap)
if USE_CUSTOM_MASK: if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
custom_mask = tl.load( custom_mask = tl.load(
mask_ptr mask_ptr
+ cur_seq_mask_start_idx + cur_seq_mask_start_idx
...@@ -302,6 +303,7 @@ def extend_attention_fwd( ...@@ -302,6 +303,7 @@ def extend_attention_fwd(
max_len_extend, max_len_extend,
sm_scale=None, sm_scale=None,
logit_cap=0.0, logit_cap=0.0,
skip_prefix_custom_mask=True,
): ):
""" """
q_extend, k_extend, v_extend, o_extend: contiguous tensors q_extend, k_extend, v_extend, o_extend: contiguous tensors
...@@ -355,6 +357,8 @@ def extend_attention_fwd( ...@@ -355,6 +357,8 @@ def extend_attention_fwd(
kv_group_num = q_extend.shape[1] // k_extend.shape[1] kv_group_num = q_extend.shape[1] // k_extend.shape[1]
USE_CUSTOM_MASK = custom_mask is not None USE_CUSTOM_MASK = custom_mask is not None
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1 num_stages = 1
...@@ -398,6 +402,7 @@ def extend_attention_fwd( ...@@ -398,6 +402,7 @@ def extend_attention_fwd(
Lq=Lq, Lq=Lq,
Lv=Lv, Lv=Lv,
USE_CUSTOM_MASK=USE_CUSTOM_MASK, USE_CUSTOM_MASK=USE_CUSTOM_MASK,
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
STORE_TRANSPOSE=is_hip_, STORE_TRANSPOSE=is_hip_,
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment