Unverified commit ac1c9342, authored by Isotr0py, committed by GitHub
Browse files

[Bugfix] Fix incorrect tiles creation for mm prefix triton attention (#30974)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 4924ac58
...@@ -189,9 +189,14 @@ def kernel_unified_attention_2d( ...@@ -189,9 +189,14 @@ def kernel_unified_attention_2d(
+ 1 + 1
) )
# adjust for potential padding in the last q_block by considering the if USE_MM_PREFIX:
# actual sequence length # image bidirectional attention ranges require a full range
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) # including q_block padding to make sure doc mask is correct
max_seq_prefix_len = tl.maximum(max_seq_prefix_len, seq_len)
else:
# adjust for potential padding in the last q_block by considering the
# actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
# calculate the number of tiles that need to be processed to # calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, tiles beyond # cover the longest sequence prefix (due to causal masking, tiles beyond
...@@ -202,7 +207,8 @@ def kernel_unified_attention_2d( ...@@ -202,7 +207,8 @@ def kernel_unified_attention_2d(
# Default: keep previous global behavior # Default: keep previous global behavior
tile_start = 0 tile_start = 0
tile_end = num_tiles tile_end = num_tiles
if SLIDING_WINDOW > 0: # TODO(Isotr0py): sliding window pruning with image bidirectional mask
if SLIDING_WINDOW > 0 and not USE_MM_PREFIX:
# Query rows covered by this Q-block # Query rows covered by this Q-block
qpos_lo = q_block_local_idx * BLOCK_Q qpos_lo = q_block_local_idx * BLOCK_Q
qpos_hi = tl.minimum( qpos_hi = tl.minimum(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment