Unverified commit ddcf9fe3 authored by Ke Bao, committed by GitHub

Optimize triton attention custom mask (#3731)

parent 6252ade9
@@ -74,6 +74,7 @@ def _fwd_kernel(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
USE_CUSTOM_MASK: tl.constexpr,
+    SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
STORE_TRANSPOSE: tl.constexpr,
):
cur_seq = tl.program_id(0)
@@ -160,7 +161,7 @@ def _fwd_kernel(
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
-        if USE_CUSTOM_MASK:
+        if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
custom_mask = tl.load(
mask_ptr
+ cur_seq_mask_start_idx
@@ -302,6 +303,7 @@ extend_attention_fwd(
max_len_extend,
sm_scale=None,
logit_cap=0.0,
+    skip_prefix_custom_mask=True,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -355,6 +357,8 @@
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
USE_CUSTOM_MASK = custom_mask is not None
+    # Skip custom mask for prefix part
+    SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1
@@ -398,6 +402,7 @@
Lq=Lq,
Lv=Lv,
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
STORE_TRANSPOSE=is_hip_,
num_warps=num_warps,
num_stages=num_stages,
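For readers skimming the diff: the new `SKIP_PREFIX_CUSTOM_MASK` constexpr guards the `tl.load` of the custom mask inside the prefix (KV-cache) loop, so with `skip_prefix_custom_mask=True` (the new default) the custom mask is only fetched and applied for the extend part. Below is a minimal PyTorch sketch of the intended semantics, not the Triton kernel itself; the function name, shapes, and the flat per-sequence mask layout are assumptions for illustration only.

```python
# Illustrative reference for what SKIP_PREFIX_CUSTOM_MASK means: when set, the
# custom mask is applied only to the extend (new-token) keys, and the prefix
# (cached) keys are treated as fully visible. All names/shapes are hypothetical.
import torch


def naive_extend_attention(q, k_prefix, v_prefix, k_extend, v_extend,
                           custom_mask=None, skip_prefix_custom_mask=True):
    # q:        [len_extend, dim]  queries for the newly extended tokens
    # k_prefix: [len_prefix, dim]  cached prefix keys (v_prefix likewise)
    # k_extend: [len_extend, dim]  keys for the new tokens (v_extend likewise)
    scale = q.shape[-1] ** -0.5
    scores_prefix = (q @ k_prefix.T) * scale   # [len_extend, len_prefix]
    scores_extend = (q @ k_extend.T) * scale   # [len_extend, len_extend]

    if custom_mask is not None:
        # Flat bool mask covering the whole (prefix + extend) key range per query row.
        mask = custom_mask.view(q.shape[0], -1)
        if not skip_prefix_custom_mask:
            # Old behavior: also load/apply the mask over the prefix part.
            scores_prefix = scores_prefix.masked_fill(
                ~mask[:, : k_prefix.shape[0]], float("-inf"))
        # Mask for the extend part is always applied.
        scores_extend = scores_extend.masked_fill(
            ~mask[:, k_prefix.shape[0]:], float("-inf"))

    scores = torch.cat([scores_prefix, scores_extend], dim=-1)
    probs = torch.softmax(scores, dim=-1)
    return probs @ torch.cat([v_prefix, v_extend], dim=0)


# Illustrative usage with made-up sizes.
d = 64
q = torch.randn(4, d)                                   # 4 new tokens
k_pre, v_pre = torch.randn(16, d), torch.randn(16, d)   # cached prefix of 16 tokens
k_ext, v_ext = torch.randn(4, d), torch.randn(4, d)
mask = torch.ones(4 * 20, dtype=torch.bool)             # flat custom mask
out = naive_extend_attention(q, k_pre, v_pre, k_ext, v_ext, mask)
```

Skipping the mask for the prefix part removes one global-memory load per (query, prefix-key) tile in the kernel's prefix loop, which is presumably where the speedup in this PR comes from when the prefix tokens are fully visible to every extend-token query.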